forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_gpt2_fp32.cu
1741 lines (1557 loc) · 75.8 KB
/
train_gpt2_fp32.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
GPT-2 Transformer Neural Net trained in raw CUDA
Non-trivial notes to be aware of:
We are being clever in the backward pass to conserve memory.
In particular, all parameters use a += in the backward pass, so we
can later do gradient accumulation. But all activations have = instead of +=
because these are faster (just a write, no read of the old value). This is okay for all activations
except for those in the residual stream, where the gradients have to add. We make
sure that those parts work out ok and that we do a += as necessary. E.g.,
the layernorms are connected to the residuals so we += in layernorm backward.
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <assert.h>
#include <float.h>
#include <string.h>
#include <unistd.h>
// GPU / CUDA related
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cublasLt.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
// our own utilities
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
#include "utils.h"
// defines: tokenizer_init, tokenizer_decode, tokenizer_free
#include "tokenizer.h"
// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
#include "dataloader.h"
// ----------------------------------------------------------------------------
// CUDA utils
// convenience macro for calculating grid/block dimensions for kernels
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
// CUDA error checking
// Terminates the program with a diagnostic if a CUDA runtime call failed.
//   error      - status code returned by the CUDA call being checked
//   file, line - location of the call site (filled in by the macro below)
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error == cudaSuccess) {
        return; // nothing to report
    }
    const char *message = cudaGetErrorString(error);
    printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, message);
    exit(EXIT_FAILURE);
}
// wrapper that records the call site automatically
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
// cuBLAS error checking
// Terminates the program with a diagnostic if a cuBLAS / cuBLASLt call failed.
//   status     - status code returned by the library call being checked
//   file, line - location of the call site (filled in by the macro below)
void cublasCheck(cublasStatus_t status, const char *file, int line)
{
    const bool ok = (status == CUBLAS_STATUS_SUCCESS);
    if (ok) {
        return;
    }
    printf("[cuBLAS ERROR]: %d %s %d\n", status, file, line);
    exit(EXIT_FAILURE);
}
// wrapper that records the call site automatically
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }
// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
// (size in bytes; also passed to cuBLASLt as the maximum workspace an algorithm may request)
static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
// workspace buffer handed to cublasLtMatmul. Starts out NULL; it is not allocated in this
// section of the file, so presumably it is set up during program initialization.
static void* cublaslt_workspace = NULL;
// compute type used when creating cuBLASLt matmul descriptors (assigned outside this section)
static cublasComputeType_t cublas_compute_type;
// library handles shared by the launchers below (created outside this section)
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
// short alias for the cooperative groups API used throughout the kernels
namespace cg = cooperative_groups;
// ----------------------------------------------------------------------------
// all the kernels
// Component-wise sum of two float4 values.
__device__ inline float4 add_float4(const float4& a, const float4& b) {
    float4 sum;
    sum.x = a.x + b.x;
    sum.y = a.y + b.y;
    sum.z = a.z + b.z;
    sum.w = a.w + b.w;
    return sum;
}
// Embedding lookup: out[b, t, :] = wte[inp[b, t], :] + wpe[t, :].
// All tensors are addressed in float4 units (C must be a multiple of 4), which lets the
// compiler emit 128-bit loads/stores; helpful because this kernel is memory-bound.
// One thread produces one float4 of the output.
__global__ void encoder_forward_kernel3(float4* out,
                                        const int* inp, const float4* wte, const float4* wpe,
                                        int B, int T, int C) {
    const int C4 = C / 4;
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * T * C4) {
        return; // tail threads of the last block have nothing to do
    }
    // split the flat index into (b, t, c4)
    const int row = gid / C4;
    const int c4 = gid % C4;
    const int b = row / T;
    const int t = row % T;
    // token id at this (b, t) position selects the row of wte
    const int token = inp[b * T + t];
    const float4 token_emb = wte[token * C4 + c4];
    const float4 pos_emb = wpe[t * C4 + c4];
    out[b * T * C4 + t * C4 + c4] = add_float4(token_emb, pos_emb);
}
// Naive embedding backward: every (b, t, c) element of dout is atomically accumulated
// into the token-embedding gradient row selected by inp[b, t] and into the
// position-embedding gradient row t. One thread per element; the atomics are slow
// but keep the accumulation correct when several positions share a token or a t.
__global__ void encoder_backward_kernel(float* dwte, float* dwpe,
                                        const float* dout, const int* inp,
                                        int B, int T, int C) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * T * C) {
        return;
    }
    // split the flat index into (b, t, c)
    const int row = gid / C;
    const int c = gid % C;
    const int b = row / T;
    const int t = row % T;
    const int token = inp[b * T + t];

    const float* grad = dout + b * T * C + t * C + c;
    atomicAdd(dwte + token * C + c, *grad);
    atomicAdd(dwpe + t * C + c, *grad);
}
// LayerNorm forward. One warp normalizes one row of C values:
//   out[row, c] = (inp[row, c] - mean) * rstd * weight[c] + bias[c]
// mean / rstd are optional outputs (skipped when the pointer is null) holding the
// per-row mean and reciprocal standard deviation.
// N is the number of rows; the grid must provide at least N warps in total.
__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                          const float* __restrict__ inp, const float* __restrict__ weight,
                                          const float* __restrict__ bias, int N, int C) {
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    const int lane = warp.thread_rank();
    const int row = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if (row >= N) {
        return; // whole warp exits together, no barriers follow
    }

    const float* x = inp + row * C;
    float* y = out + row * C;

    // pass 1: mean of the row
    float acc = 0.0f;
    for (int c = lane; c < C; c += warp.size()) {
        acc += x[c];
    }
    acc = cg::reduce(warp, acc, cg::plus<float>{});
    const float row_mean = acc / C;
    if (lane == 0 && mean != nullptr) {
        __stcs(mean + row, row_mean);
    }

    // pass 2: reciprocal standard deviation (biased variance plus epsilon)
    acc = 0.0f;
    for (int c = lane; c < C; c += warp.size()) {
        const float centered = x[c] - row_mean;
        acc += centered * centered;
    }
    acc = cg::reduce(warp, acc, cg::plus<float>{});
    const float row_rstd = rsqrtf(acc / C + 1e-5f);
    if (lane == 0 && rstd != nullptr) {
        __stcs(rstd + row, row_rstd);
    }

    // pass 3: normalize, scale and shift.
    // The row data is read/written with the streaming (.cs) cache hint because it is not
    // reused soon; that leaves more cache for weight and bias, which every warp shares.
    for (int c = lane; c < C; c += warp.size()) {
        const float normalized = row_rstd * (__ldcs(x + c) - row_mean);
        __stcs(y + c, normalized * weight[c] + bias[c]);
    }
}
// Splits a packed QKV tensor `inp` of shape (B, N, 3, NH, d) into three separate
// tensors q, k, v, each of shape (B, NH, N, d). One thread per output element:
//   q[b][h][n][i] = inp[b][n][0][h][i], and likewise for k (slot 1) and v (slot 2).
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B * NH * N * d) {
        return;
    }
    // decompose tid as a flat index into (B, NH, N, d)
    const int b = tid / (NH * N * d);
    const int within_batch = tid % (NH * N * d);
    const int head = within_batch / (N * d);
    const int within_head = within_batch % (N * d);
    const int pos = within_head / d;
    const int chan = within_head % d;

    // flat index of the Q element in the packed layout; K and V follow one and two
    // (NH * d)-sized slots later
    const int slot = NH * d;
    const int src = (b * N * 3 * NH * d) + (pos * 3 * NH * d) + (0 * NH * d) + (head * d) + chan;
    q[tid] = __ldcs(&inp[src]);
    k[tid] = __ldcs(&inp[src + slot]);
    v[tid] = __ldcs(&inp[src + 2 * slot]);
}
// Inverse of permute_kernel for gradients: gathers dq, dk, dv (each (B, NH, N, d))
// back into the packed layout dinp of shape (B, N, 3, NH, d). Values are assigned,
// not accumulated. One thread per (b, head, n, i) element.
__global__ void permute_kernel_backward(float* dinp,
                                        const float* dq, const float* dk, const float* dv,
                                        int B, int N, int NH, int d) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B * NH * N * d) {
        return;
    }
    // decompose tid as a flat index into (B, NH, N, d)
    const int b = tid / (NH * N * d);
    const int within_batch = tid % (NH * N * d);
    const int head = within_batch / (N * d);
    const int within_head = within_batch % (N * d);
    const int pos = within_head / d;
    const int chan = within_head % d;

    // destination of the Q gradient; K and V gradients live one and two slots later
    const int slot = NH * d;
    const int dst = (b * N * 3 * NH * d) + (pos * 3 * NH * d) + (0 * NH * d) + (head * d) + chan;
    dinp[dst] = dq[tid];
    dinp[dst + slot] = dk[tid];
    dinp[dst + 2 * slot] = dv[tid];
}
// Re-assembles per-head outputs: reads `inp` laid out as (B, NH, N, d) and writes
// `out` laid out as (B, N, NH, d), i.e. out[b][n][h][i] = inp[b][h][n][i].
// One thread per element.
__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B * NH * N * d) {
        return;
    }
    // decompose tid as a flat index into the (B, NH, N, d) source
    const int b = tid / (NH * N * d);
    const int within_batch = tid % (NH * N * d);
    const int head = within_batch / (N * d);
    const int within_head = within_batch % (N * d);
    const int pos = within_head / d;
    const int chan = within_head % d;

    const int dst = (b * NH * N * d) + (pos * NH * d) + (head * d) + chan;
    out[dst] = __ldcs(&inp[tid]);
}
// Gradient of unpermute_kernel: dinp (B, NH, N, d) receives the values of dout
// (B, N, NH, d), i.e. dinp[b][h][n][i] = dout[b][n][h][i]. One thread per element.
__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B * NH * N * d) {
        return;
    }
    // decompose tid as a flat index into the (B, NH, N, d) destination
    const int b = tid / (NH * N * d);
    const int within_batch = tid % (NH * N * d);
    const int head = within_batch / (N * d);
    const int within_head = within_batch % (N * d);
    const int pos = within_head / d;
    const int chan = within_head % d;

    const int src = (b * NH * N * d) + (pos * NH * d) + (head * d) + chan;
    dinp[tid] = dout[src];
}
// Indexed access to the components of a float4 (index 0..3 selects x, y, z, w).
// Mutable overload: returns a reference into `vec`.
__device__ float& vec_at(float4& vec, int index) {
    float* components = reinterpret_cast<float*>(&vec);
    return components[index];
}
// Read-only overload: returns the selected component by value.
__device__ float vec_at(const float4& vec, int index) {
    const float* components = reinterpret_cast<const float*>(&vec);
    return components[index];
}
// Causal (lower-triangular) softmax over attention scores, one warp per row.
//   out             - destination, same layout as inp; only columns 0..own_pos of each row are written
//   inv_temperature - scale multiplied into the logits before exponentiation
//   inp             - raw attention scores
//   N, T            - number of (batch, head) pairs and sequence length
// Requires T % 4 == 0 (rows are read as float4); the grid must supply at least N * T warps.
__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
// inp, out shape: (N, T, T), where N = B * NH
// fuses the multiplication by scale inside attention
// directly autoregressive, so we only compute the lower triangular part
// uses the online softmax algorithm
assert(T % 4 == 0);
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
// micro-optimization: we iterate backwards so that
// after this softmax kernel completes, the cache retains the
// part of the matrix close to the upper left corner, which benefits the
// matmul operation that immediately follows.
// int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); // forward order
int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank(); // backward order
// idx is uniform within a warp, so the whole warp leaves together
if(idx >= N * T) {
return;
}
// position of this row inside its (T, T) matrix; columns > own_pos are masked out
int own_pos = idx % T;
// number of complete float4 groups that lie strictly before own_pos's group
int pos_by_4 = own_pos / 4;
// one row of inp, i.e. inp[idx, :] of shape (T,)
const float* x = inp + idx * T;
// not INF, so we don't get NaNs accidentally when subtracting two values.
float maxval = -FLT_MAX;
float sumval = 0.0f;
// vectorized part: each lane keeps a running max and a running sum of exponentials,
// rescaling the sum whenever its max grows (online softmax)
const float4* x_vec = reinterpret_cast<const float4*>(x);
for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
float4 v = x_vec[i];
float old_maxval = maxval;
for(int k = 0; k < 4; ++k) {
maxval = fmaxf(maxval, vec_at(v, k));
}
sumval *= expf(inv_temperature * (old_maxval - maxval));
for(int k = 0; k < 4; ++k) {
sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
}
}
// scalar tail: the remaining 1..4 elements up to and including own_pos,
// handled by the first lanes of the warp
if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
float old_maxval = maxval;
maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
sumval *= expf(inv_temperature * (old_maxval - maxval));
sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
}
// combine the per-lane results: row-wide max first, then rescale each lane's sum to
// that max before adding the sums together
float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
sumval *= expf(inv_temperature * (maxval - global_maxval));
float sum = cg::reduce(warp, sumval, cg::plus<float>{});
float norm = 1.f / sum;
// divide the whole row by the sum
for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
// recalculation is faster than doing the round-trip through memory.
float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
__stcs(out + idx * T + i, ev * norm);
}
}
// Element-wise residual add: out[i] = inp1[i] + inp2[i] for i in [0, N).
// Inputs are read with the streaming cache hint. One thread per element.
__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N) {
        return;
    }
    const float lhs = __ldcs(&inp1[i]);
    const float rhs = __ldcs(&inp2[i]);
    out[i] = lhs + rhs;
}
// sqrt(2/pi), the constant of the tanh approximation of GELU
#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
// GELU forward (tanh approximation), element-wise over N values:
//   out = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// One thread per element.
__global__ void gelu_forward_kernel(float* out, const float* inp, int N) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= N) {
        return;
    }
    const float x = inp[tid];
    const float cubic_term = 0.044715f * x * x * x;
    out[tid] = 0.5f * x * (1.0f + tanhf(GELU_SCALING_FACTOR * (x + cubic_term)));
}
// GELU backward (tanh approximation), element-wise over N values:
//   dinp[i] = dGELU/dx(inp[i]) * dout[i]
// The result is assigned (not accumulated). One thread per element.
__global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= N) {
        return;
    }
    const float x = inp[tid];
    const float cube = 0.044715f * x * x * x;
    const float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
    const float tanh_out = tanhf(tanh_arg);
    // sech^2(u) computed as 1 / cosh^2(u)
    const float cosh_val = coshf(tanh_arg);
    const float sech_out = 1.0f / (cosh_val * cosh_val);
    // derivative of 0.5 * x * (1 + tanh(u(x))) by the product and chain rules
    const float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
    dinp[tid] = local_grad * dout[tid];
}
// this kernel performs a column-wise reduction over dout, in PyTorch equivalent to:
// dbias = dout.sum((0,1))
// the idea is to employ one block to reduce along several columns,
// where each block has a width of 32 columns to ensure coalesced access.
// at the end we accumulate the reductions performed by the warps in each block via shared memory
// Accumulates dbias[col] += sum over all B*T rows of dout[row, col].
// Launch layout: 1D grid of OC / 32 blocks, each block owns 32 adjacent columns;
// dynamic shared memory of blockDim.x floats. There is no column bounds check, so
// OC has to be a multiple of 32.
__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
    extern __shared__ float smem[]; // blockDim.x partial sums, indexed [warp][lane]
    const int lane = threadIdx.x % warpSize;           // selects the column within the block's 32
    const int warp = threadIdx.x / warpSize;           // selects which rows this thread visits
    const int warps_per_block = blockDim.x / warpSize;
    const int col = blockIdx.x * warpSize + lane;      // global column handled by this thread
    const int num_rows = B * T;

    // Each warp walks the rows with a stride of warps_per_block. The lanes of a warp
    // read 32 adjacent columns of the same row, so the loads are coalesced.
    float partial = 0.0f;
    for (int row = warp; row < num_rows; row += warps_per_block) {
        partial += dout[row * OC + col];
    }
    smem[warp * warpSize + lane] = partial;
    __syncthreads();

    // warp 0 adds up the per-warp partials of each column and accumulates into dbias
    if (warp != 0) {
        return;
    }
    float total = 0.0f;
    for (int w = 0; w < warps_per_block; w++) {
        total += smem[w * warpSize + lane];
    }
    dbias[col] += total;
}
// uses shared memory instead for the reduces
// LayerNorm backward. One warp handles one (b, t) row of C channels.
//   dinp            - gradient w.r.t. the input, ACCUMULATED (+=) because it feeds the residual stream
//   dweight, dbias  - parameter gradients, accumulated with atomicAdd
//   dout            - incoming gradient, (B, T, C)
//   inp, weight     - forward inputs
//   mean, rstd      - per-row statistics saved by the forward pass, (B, T)
// Launch: dynamic shared memory of 2 * C floats; the grid must provide at least B*T warps.
//
// Warps whose row index falls past the end (possible in the last block when B*T is not
// a multiple of the warps per block) must NOT return early: every thread of the block
// has to take part in zeroing and flushing the shared accumulators and has to reach
// both __syncthreads() barriers. Such warps simply skip the per-row work.
__global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,
                                           const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    int N = B * T;
    // uniform across the warp, so the warp-level reductions below never diverge
    const bool has_row = idx < N;

    // the first half of shared memory is bias, second is weight
    float* dbias_shared = shared;
    float* dweight_shared = shared + C;

    // init shared memory to zero (all threads of the block participate)
    #pragma unroll
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    __syncthreads();

    if (has_row) {
        int b = idx / T;
        int t = idx % T;

        const float* dout_bt = dout + b * T * C + t * C;
        const float* inp_bt = inp + b * T * C + t * C;
        float* dinp_bt = dinp + b * T * C + t * C;
        const float mean_bt = mean[b * T + t];
        const float rstd_bt = rstd[b * T + t];

        // first: two reduce operations
        float dnorm_mean = 0.0f;
        float dnorm_norm_mean = 0.0f;
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = weight[i] * dout_bt[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
        dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
        dnorm_mean = dnorm_mean / C;
        dnorm_norm_mean = dnorm_norm_mean / C;

        // now iterate again and accumulate all the gradients
        for (int i = warp.thread_rank(); i < C; i += warp.size()) {
            float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
            float dnorm_i = weight[i] * dout_bt[i];
            // gradient contribution to bias (other warps of the block hit the same slot)
            atomicAdd(&dbias_shared[i], dout_bt[i]);
            // gradient contribution to weight
            atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]);
            // gradient contribution to input
            float dval = 0.0f;
            dval += dnorm_i; // term 1
            dval -= dnorm_mean; // term 2
            dval -= norm_bti * dnorm_norm_mean; // term 3
            dval *= rstd_bt; // final scale
            dinp_bt[i] += dval;
        }
    }
    __syncthreads();

    // write to global memory: one atomic per channel per block
    for(int i = threadIdx.x; i < C; i+= blockDim.x){
        atomicAdd(&dbias[i], dbias_shared[i]);
        atomicAdd(&dweight[i], dweight_shared[i]);
    }
}
// Backward pass of the causal softmax used in attention.
// For every row t of every (T, T) attention matrix:
//   dpreatt[t, j] = scale * att[t, j] * (datt[t, j] - sum_k att[t, k] * datt[t, k]),  j <= t
// Entries with j > t are not written.
// Launch layout: blockDim.x must equal BlockSize (256); blockIdx.y selects the
// (batch, head) matrix; each block along x processes T_per_block consecutive rows.
// B and C are not used by the kernel body.
__global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att,
                                                       int B, int T, int C, float scale) {
    constexpr const int BlockSize = 256;
    constexpr int T_per_block = 4;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    __shared__ float block_acc[32];

    int idx = blockIdx.y;
    // go through blocks in reverse order, so the slowest block starts first
    int t0 = T - 1 - T_per_block*blockIdx.x;

    att += idx * T * T;
    datt += idx * T * T;
    dpreatt += idx * T * T;

    // slots beyond the number of warps are never written again and must read as zero
    if (warp.meta_group_rank() == 0) {
        block_acc[warp.thread_rank()] = 0;
    }
    // make the zero-initialization visible before any warp stores its partial sum;
    // without this barrier warp 0 could clear a slot another warp has already filled
    block.sync();

    for(int to = 0; to < T_per_block; ++to) {
        int t = t0 - to;
        // t depends only on blockIdx and the loop counter, so the whole block exits together
        if(t < 0) return;
        const float* att_bth = att + t * T;
        const float* datt_bth = datt + t * T;
        float* dpreatt_bth = dpreatt + t * T;

        // dot product of the att and datt rows over the causal range [0, t]
        float local_sum = 0;
        for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
            local_sum += att_bth[t2] * datt_bth[t2];
        }

        // two-stage block reduction: per-warp partials go through shared memory
        block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});
        block.sync();
        local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});
        // every warp must finish reading the partials before the next iteration
        // overwrites them
        block.sync();

        for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {
            // don't touch the cache. Some parts will still be here from the previous loop, and
            // we want to exploit those.
            float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum);
            __stcs(dpreatt_bth + t3, scale * acc);
        }
    }
}
// Implements linear interpolation using only two floating-point operations (as opposed to three in a naive implementation).
// Reference: https://developer.nvidia.com/blog/lerp-faster-cuda
// Linear interpolation start + weight * (end - start), evaluated as two fused
// multiply-adds: first start - weight * start, then weight * end added on top.
__device__ inline float lerp(float start, float end, float weight) {
    const float start_part = fma(-weight, start, start);
    return fma(weight, end, start_part);
}
// One AdamW step, one thread per parameter.
//   params_memory, grads_memory, m_memory, v_memory - flat arrays of num_parameters floats
//   beta1_correction, beta2_correction - bias-correction divisors for the two moments
//     (presumably 1 - beta^t, computed by the caller)
//   weight_decay - decoupled weight decay coefficient
// The thread index is computed in 64 bits so that it cannot wrap around before being
// compared with num_parameters (which is a long) when the model is very large.
__global__ void adamw_kernel2(float* params_memory, float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    const long long i = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters) return;  // guard
    float grad = grads_memory[i];
    float m = m_memory[i];
    float v = v_memory[i];
    // update the first moment (momentum): m = beta1 * m + (1 - beta1) * grad
    m = lerp(grad, m, beta1);
    m_memory[i] = m;
    // update the second moment (RMSprop): v = beta2 * v + (1 - beta2) * grad^2
    v = lerp(grad * grad, v, beta2);
    v_memory[i] = v;
    m /= beta1_correction;  // m_hat
    v /= beta2_correction;  // v_hat
    params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
}
// Per-row softmax normalization constants, as produced by
// prepare_softmax_blockwide_nofloat4: prob = expf(logit - Offset) * Scale.
struct SoftmaxParams {
// reciprocal of the sum of exponentials of the row
float Scale;
// maximum value of the row, subtracted before exponentiation
float Offset;
};
// Computes the softmax normalization constants (max and 1/sum-of-exps) for one row of
// logits, using all threads of the block.
//   warp - the calling thread's 32-wide tile
//   idx  - row index
//   inp  - logits, rows are P floats apart
//   V    - number of valid entries per row
//   P    - row stride in floats
// Contains __syncthreads(), so every thread of the block must call it. The cross-warp
// stage keeps one shared slot per warp in 32-element arrays, i.e. at most 32 warps.
__device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp,
int idx, const float* inp, int V, int P) {
// same but not float4
// one row of inp, i.e. inp[idx, :] of shape (V,)
const float* x = inp + idx * P;
float thread_maxval = -INFINITY;
float thread_sumval = 0.0f;
// do the loop in reverse to maximise probability of L2 cache hits
// so even small L2s get some hits on the 2nd read of the same thread
for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
float v = x[i];
// online softmax: rescale the running sum whenever the running max grows
float old_maxval = thread_maxval;
thread_maxval = fmaxf(thread_maxval, v);
thread_sumval *= expf((old_maxval - thread_maxval));
thread_sumval += expf(v - thread_maxval);
}
// two reductions of up to 1024 threads:
// 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
// this results in much cleaner assembly than a multi-warp cg::reduce
__shared__ float shared_maxval[32];
__shared__ float shared_sumval[32];
int num_warps = blockDim.x / 32;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
// reduce maxval within each warp
float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});
// thread 0 in each warp writes to shared memory
if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }
__syncthreads();
// each thread now loads the maxval across previous warps
// if the thread is "out of range" of data, use -FLT_MAX as the maxval
warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;
// now reduce the maxval among the warp threads
float block_maxval = cg::reduce(warp, warp_maxval, cg::greater<float>{});
// each thread uses maxval to scale sumval to avoid numerical instability / overflow
thread_sumval *= expf(thread_maxval - block_maxval);
// (warp-level) reduce sumval, thread 0 in each warp saves result in shared memory
float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus<float>{});
if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }
__syncthreads();
// same strategy, now reduce sumval across warps
warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;
float block_sumval = cg::reduce(warp, warp_sumval, cg::plus<float>{});
// return the softmax parameters
return SoftmaxParams{1.f / block_sumval, block_maxval};
}
// same as 2 but not using float4 (see dev/cuda/classifier_fused.cu)
// will _update_ logits to logit gradients
//
// Fused softmax + cross-entropy loss + gradient. One block per (b, t) row.
//   logits  - (B*T, P) row-major; read as logits, then overwritten in place with dloss/dlogit
//   losses  - (B*T,) output, -log(prob of the target token)
//   probs   - optional (B*T, P) output of the softmax probabilities, may be NULL
//   dlosses - optional (B*T,) upstream gradient per row; NULL means 1 / (B*T)
//   targets - (B*T,) target token index per row
//   V, P    - number of valid logits per row and the (padded) row stride
__global__ void fused_classifier_kernel3(float* logits, float* losses, float* probs,
                                         const float* dlosses, const int* targets,
                                         int B, int T, int V, int P) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x;
    int ix = targets[idx];

    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
    SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P);

    // calculate the probability needed for the loss and update (single-threaded)
    if(threadIdx.x == 0) {
        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;
        losses[idx] = -logf(prob);
    }
    // The loop below overwrites the logits in place. Without a barrier, the thread that
    // owns column ix could store the gradient before thread 0 has read the original
    // logit above, which would corrupt the loss.
    __syncthreads();

    // very sensible default for dlosses is 1/(B*T), which is the uniform loss
    float dloss = dlosses != NULL ? dlosses[idx] : 1.0f / (B*T);
    // calculate the gradients directly, saves bandwidth from probs during training
    // but also supports writing probs for inference-only and debugging
    const float* logits_vec = logits + idx * P;
    for (int i = threadIdx.x; i < V; i += blockDim.x) {
        // this is the 2nd read of logits after the one in prepare_softmax2
        // this data will never be needed again, so we reduce cache persistence
        float v = __ldcs(&logits_vec[i]);
        float prob = expf(v - sp.Offset) * sp.Scale;
        if (probs != NULL) {
            probs[idx * P + i] = prob;
        }
        // gradient of cross-entropy w.r.t. the logits: softmax minus one-hot target
        float indicator = (i == ix) ? 1.0f : 0.0f;
        logits[idx * P + i] = (prob - indicator) * dloss;
    }
}
// ----------------------------------------------------------------------------
// kernel launchers
// Launches the embedding-lookup kernel: out (B, T, C) = wte[inp] + wpe.
// C must be a multiple of 4 because the kernel works on float4 elements.
void encoder_forward(float* out,
                     const int* inp, const float* wte, const float* wpe,
                     int B, int T, int C) {
    assert(C % 4 == 0);
    const int threads_per_block = 512;
    const int num_floats = B * T * C;
    // one thread per float4
    const int num_blocks = CEIL_DIV(num_floats / 4, threads_per_block);
    float4* out4 = (float4*) out;
    float4* wte4 = (float4*) wte;
    float4* wpe4 = (float4*) wpe;
    encoder_forward_kernel3<<<num_blocks, threads_per_block>>>(out4, inp, wte4, wpe4, B, T, C);
    cudaCheck(cudaGetLastError());
}
// Launches the embedding backward kernel, which accumulates dout (B, T, C) into the
// token (dwte) and position (dwpe) embedding gradients. One thread per dout element.
void encoder_backward(float* dwte, float* dwpe,
                      const float* dout, const int* inp,
                      int B, int T, int C) {
    const int threads_per_block = 256;
    const int num_elements = B * T * C;
    const int num_blocks = CEIL_DIV(num_elements, threads_per_block);
    encoder_backward_kernel<<<num_blocks, threads_per_block>>>(dwte, dwpe, dout, inp, B, T, C);
    cudaCheck(cudaGetLastError());
}
// Launches LayerNorm forward over B*T rows of C channels. The kernel assigns one
// warp (32 threads) to each row, so the grid is sized for B*T*32 threads.
void layernorm_forward(float* out, float* mean, float* rstd,
                       float* inp, float* weight, float* bias,
                       int B, int T, int C) {
    const int threads_per_block = 512;
    const int num_rows = B * T;
    const int num_blocks = CEIL_DIV(num_rows * 32, threads_per_block);
    layernorm_forward_kernel3<<<num_blocks, threads_per_block>>>(out, mean, rstd, inp, weight, bias, num_rows, C);
    cudaCheck(cudaGetLastError());
}
// uses cuBLASLt to run the matmul with the bias add fused in as an epilogue
// (only the BIAS epilogue is configured here; no GELU is fused by this function).
// does not work with OC = 50257 (last layer)
// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu
//
// Given the layouts created below, in row-major terms:
//   out (B*T, OC) = inp (B*T, C) @ weight (OC, C)^T  [+ bias (OC,) when bias != NULL]
// Exits the process if the bias pointer is not 16-byte aligned, if cuBLASLt finds no
// algorithm, or if any cuBLASLt call fails.
void matmul_forward_cublaslt(float* out,
float* inp, float* weight, float* bias,
int B, int T, int C, int OC) {
int has_bias = (bias != NULL);
// check bias alignment (a NULL bias passes this check)
if(((uintptr_t)bias % 16) != 0) {
printf("Bias pointer is not aligned (cuBLASLt requirement)!\n");
exit(EXIT_FAILURE);
}
int returnedResults = 0;
cublasLtMatmulDesc_t operationDesc;
cublasLtMatmulPreference_t preference;
cublasLtMatrixLayout_t weightLayout;
cublasLtMatrixLayout_t inputLayout;
cublasLtMatrixLayout_t outputLayout;
cublasLtMatrixLayout_t biasLayout;
cublasLtMatmulHeuristicResult_t heuristic;
// create the operation descriptor
cublasOperation_t opNoTranspose = CUBLAS_OP_N;
cublasOperation_t opTranspose = CUBLAS_OP_T;
cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_BIAS;
cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F));
// A (the weight) is transposed, B (the input) is used as-is
cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose)));
cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose)));
// the bias epilogue is only enabled when a bias was supplied
if(has_bias) {
cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias,
sizeof(epilogueBias)));
}
// the bias pointer attribute is set unconditionally (it is NULL when there is no bias)
cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
// define matrix layouts (arguments are: type, rows, cols, leading dimension)
cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C));
cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C));
cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC));
// NOTE(review): biasLayout is created and destroyed but not passed to any other call here
cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC));
// create a preference handle with specified max workspace
cublasCheck(cublasLtMatmulPreferenceCreate(&preference));
cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));
// find a suitable algorithm (only the top heuristic result is requested)
cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc,
weightLayout, inputLayout, outputLayout, outputLayout,
preference, 1, &heuristic, &returnedResults));
if (returnedResults == 0) {
printf("No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d\n", B, T, C, OC, has_bias);
exit(EXIT_FAILURE);
}
// call the matmul; beta = 0 so the previous contents of out are not added in,
// and out is passed as both the C and D matrix. The last argument (0) is the stream.
const float alpha = 1.0f, beta = 0.0f;
cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc,
&alpha, weight, weightLayout, inp, inputLayout, &beta,
out, outputLayout, out, outputLayout, &heuristic.algo,
cublaslt_workspace, cublaslt_workspace_size, 0));
// cleanups
cublasCheck(cublasLtMatmulPreferenceDestroy(preference));
cublasCheck(cublasLtMatmulDescDestroy(operationDesc));
cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout));
cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
}
// Causal multi-head self-attention forward pass.
//   out  - (B, T, C) attention output with heads re-assembled side by side
//   qkvr - scratch/output holding the permuted Q, K, V: three consecutive (B, NH, T, HS) tensors
//   att  - (B, NH, T, T) softmax probabilities (lower triangle written by the softmax kernel)
//   inp  - (B, T, 3C) packed QKV on entry; OVERWRITTEN: reused for the pre-softmax scores
//          (B*NH*T*T floats are written through it) and then for the per-head outputs,
//          so the buffer behind it must be large enough for the larger of those
//   NH   - number of heads; head size is C / NH
// The softmax kernel asserts T % 4 == 0.
void attention_forward(float* out, float* qkvr, float* att,
float* inp,
int B, int T, int C, int NH) {
// Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
// Its contents will be overwritten by this function.
const int block_size = 256;
const int softmax_block_size = 256;
// inp is (B, T, 3C) QKV
// preatt, att are (B, NH, T, T)
// output is (B, T, C)
int HS = C / NH; // head size
// permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
float *q, *k, *v;
q = qkvr + 0 * B * T * C;
k = qkvr + 1 * B * T * C;
v = qkvr + 2 * B * T * C;
int total_threads = B * NH * T * HS;
int num_blocks = CEIL_DIV(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
cudaCheck(cudaGetLastError());
// batched matrix multiply with cuBLAS: one (T, T) score matrix per (batch, head),
// written into the (now free) inp buffer
const float alpha = 1.0f;
const float beta = 0.0f;
float* preatt = inp;
cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &alpha, k, HS, T * HS, q, HS, T * HS, &beta, preatt, T, T * T, B * NH));
// the 1/sqrt(HS) scaling of preatt is not a separate pass: it is applied inside the softmax kernel
float scale = 1.0 / sqrtf(HS);
// one warp (32 threads) per row of the B*NH (T, T) matrices
int grid_size = CEIL_DIV(B * NH * T * 32, softmax_block_size);
softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
cudaCheck(cudaGetLastError());
// new approach: first cuBLAS another batched matmul
// (preatt has been consumed by the softmax, so inp is reused once more for the result)
float* vaccum = inp;
// y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha, v, HS, T * HS, att, T, T * T, &beta, vaccum, HS, T * HS, B * NH));
// now unpermute
// y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
num_blocks = CEIL_DIV(B * T * C, block_size);
unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
cudaCheck(cudaGetLastError());
}
// Launches the element-wise residual add out = inp1 + inp2 over N floats.
void residual_forward(float* out, float* inp1, float* inp2, int N) {
    const int threads_per_block = 256;
    const int num_blocks = CEIL_DIV(N, threads_per_block);
    residual_forward_kernel<<<num_blocks, threads_per_block>>>(out, inp1, inp2, N);
    cudaCheck(cudaGetLastError());
}
// Host launcher for gelu_forward_kernel over N elements.
// Uses 128-thread blocks and ceil(N / 128) blocks; checks for launch errors only.
void gelu_forward(float* out, const float* inp, int N) {
    const int threads_per_block = 128;
    const int num_blocks = CEIL_DIV(N, threads_per_block);
    gelu_forward_kernel<<<num_blocks, threads_per_block>>>(out, inp, N);
    cudaCheck(cudaGetLastError());
}
// Host launcher for gelu_backward_kernel over N elements.
// dinp receives the result; inp and dout are read-only inputs to the kernel.
// Uses 128-thread blocks and ceil(N / 128) blocks; checks for launch errors only.
void gelu_backward(float* dinp, const float* inp, const float* dout, const int N) {
    const int threads_per_block = 128;
    const int num_blocks = CEIL_DIV(N, threads_per_block);
    gelu_backward_kernel<<<num_blocks, threads_per_block>>>(dinp, inp, dout, N);
    cudaCheck(cudaGetLastError());
}
// Backward pass of the matmul (linear) layer.
//   dinp    : overwritten with the input gradient   (GEMM with beta = 0)
//   dweight : weight gradient is accumulated into it (GEMM with beta = 1)
//   dbias   : optional; when non-NULL the bias-gradient kernel is launched on it
//   dout, inp, weight : upstream gradient, forward input and weights
//   B, T, C, OC : batch, sequence length, input channels, output channels
void matmul_backward(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight,
                     int B, int T, int C, int OC) {
    float alpha = 1.0f;
    float beta_overwrite = 0.0f;
    float beta_accumulate = 1.0f;

    // gradient w.r.t. the input: plain assignment (=), any previous contents of dinp are discarded
    cublasCheck(cublasSgemm(cublas_handle,
                            CUBLAS_OP_N, CUBLAS_OP_N,
                            C, B*T, OC,
                            &alpha,
                            weight, C,
                            dout, OC,
                            &beta_overwrite,
                            dinp, C));

    // gradient w.r.t. the weight: accumulated (+=) on top of whatever dweight already holds
    cublasCheck(cublasSgemm(cublas_handle,
                            CUBLAS_OP_N, CUBLAS_OP_T,
                            C, OC, B*T,
                            &alpha,
                            inp, C,
                            dout, OC,
                            &beta_accumulate,
                            dweight, C));

    // gradient w.r.t. the bias is optional
    if (dbias == NULL) {
        return;
    }
    const int threads_per_block = 1024;
    // for now, OC must be divisible by 32 for this kernel to work
    const int num_blocks = OC / 32;
    // dynamic shared memory: one float per thread in the block
    const size_t smem_bytes = threads_per_block * sizeof(float);
    matmul_backward_bias_kernel4<<<num_blocks, threads_per_block, smem_bytes>>>(dbias, dout, B, T, OC);
    cudaCheck(cudaGetLastError());
}
// Host launcher for layernorm_backward_kernel2.
//   dinp, dweight, dbias : gradient buffers handed to the kernel
//   dout, inp, weight, mean, rstd : read-only inputs handed to the kernel
//   B, T, C : batch, sequence length, channels
// Launch shape: 512-thread blocks, with 32 threads budgeted per (b, t) position
// (presumably one warp per row -- confirm against the kernel), and 2*C floats
// of dynamic shared memory per block.
void layernorm_backward(float* dinp, float* dweight, float* dbias,
                        const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
                        int B, int T, int C) {
    const int threads_per_block = 512;
    const int num_rows = B * T;
    const int num_blocks = CEIL_DIV(32*num_rows, threads_per_block);
    size_t smem_bytes = 2 * C * sizeof(float);
    layernorm_backward_kernel2<<<num_blocks, threads_per_block, smem_bytes>>>(
        dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
    cudaCheck(cudaGetLastError());
}
// Backward pass of the attention compound op. The forward chain being undone is:
// inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)
//
//   dinp    : written by permute_kernel_backward from dq, dk, dv
//   dqkvr   : holds dq, dk, dv as three consecutive regions of B*T*C floats
//   dpreatt : written by the softmax backward kernel, then read by the dq/dk GEMMs
//   datt    : written by the first GEMM, then read by the softmax backward kernel
//   scratch : receives dout after the backward of the unpermute step
//   dout, qkvr, att : read-only upstream gradient and saved forward tensors
//
// Every GEMM below runs with beta = 0, i.e. its output buffer is overwritten.
void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch,
                        const float* dout,
                        const float* qkvr, const float* att,
                        int B, int T, int C, int NH) {
    const int block_size = 256;
    const int HS = C / NH; // head size
    const float alpha = 1.0f;
    const float beta = 0.0f; // overwrite the destination of each GEMM

    // q, k, v (and their gradients) live back to back, B*T*C floats apiece
    const float* q = qkvr + 0 * B * T * C;
    const float* k = qkvr + 1 * B * T * C;
    const float* v = qkvr + 2 * B * T * C;
    float* dq = dqkvr + 0 * B * T * C;
    float* dk = dqkvr + 1 * B * T * C;
    float* dv = dqkvr + 2 * B * T * C;

    // step 1: undo the unpermute, dout -> scratch
    int num_blocks = CEIL_DIV(B * T * C, block_size);
    unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);
    cudaCheck(cudaGetLastError());

    // step 2: gradient of the attention weights
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_T, CUBLAS_OP_N,
                                          T, T, HS,
                                          &alpha,
                                          v, HS, T * HS,
                                          scratch, HS, T * HS,
                                          &beta,
                                          datt, T, T * T,
                                          B * NH));

    // step 3: gradient of v
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_T,
                                          HS, T, T,
                                          &alpha,
                                          scratch, HS, T * HS,
                                          att, T, T * T,
                                          &beta,
                                          dv, HS, T * HS,
                                          B * NH));

    // step 4: through the softmax, datt -> dpreatt.
    // The grid is (T / 4, B * NH) blocks of 256 threads.
    const float scale = 1.0f / sqrtf(HS);
    const dim3 softmax_grid(T / 4, B * NH);
    softmax_autoregressive_backward_kernel<<<softmax_grid, 256>>>(dpreatt, datt, att, B, T, C, scale);
    cudaCheck(cudaGetLastError());

    // step 5: gradient of q
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_N,
                                          HS, T, T,
                                          &alpha,
                                          k, HS, T * HS,
                                          dpreatt, T, T * T,
                                          &beta,
                                          dq, HS, T * HS,
                                          B * NH));

    // step 6: gradient of k
    cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                          CUBLAS_OP_N, CUBLAS_OP_T,
                                          HS, T, T,
                                          &alpha,
                                          q, HS, T * HS,
                                          dpreatt, T, T * T,
                                          &beta,
                                          dk, HS, T * HS,
                                          B * NH));

    // step 7: undo the permute, (dq, dk, dv) -> dinp
    num_blocks = CEIL_DIV(B * NH * T * HS, block_size);
    permute_kernel_backward<<<num_blocks, block_size>>>(dinp, dq, dk, dv, B, T, NH, HS);
    cudaCheck(cudaGetLastError());
}
// Host launcher for fused_classifier_kernel3; replaces logits with logit gradients.
//   logits  : read and then overwritten in place by the kernel
//   losses  : per-position loss buffer handed to the kernel
//   dlosses, targets : read-only inputs handed to the kernel
//   B, T, V, P : forwarded unchanged to the kernel
// One 1024-thread block is launched per (b, t) position; the kernel's third
// argument is always passed as NULL from here.
void fused_classifier3(float* logits, float* losses,
                       const float* dlosses, const int* targets,
                       int B, int T, int V, int P) {
    const int threads_per_block = 1024;
    const int num_positions = B * T;
    fused_classifier_kernel3<<<num_positions, threads_per_block>>>(
        logits, losses, NULL, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}
// ----------------------------------------------------------------------------
// GPT-2 model definition
// Hyperparameters describing the shape of a GPT-2 model.
// fill_in_parameter_sizes derives the parameter tensor sizes from
// padded_vocab_size, channels, max_seq_len and num_layers;
// fill_in_activation_sizes additionally reads num_heads.
typedef struct {
int max_seq_len; // max sequence length, e.g. 1024
int vocab_size; // vocab size, e.g. 50257
int padded_vocab_size; // padded to e.g. %128==0, 50304
int num_layers; // number of layers, e.g. 12
int num_heads; // number of heads in attention, e.g. 12
int channels; // number of channels, e.g. 768
} GPT2Config;
// the parameters of the model
// Each field points into one contiguous allocation made by
// malloc_and_point_parameters, which hands out the regions in exactly this
// declaration order using the sizes from fill_in_parameter_sizes.
// The field order here, the ptrs[] table there and the param_sizes[] indices
// must therefore stay in sync, and NUM_PARAMETER_TENSORS must equal the
// number of fields below.
// Shape legend: V = vocab, maxT = max sequence length, L = layers, C = channels.
#define NUM_PARAMETER_TENSORS 16
typedef struct {
// NOTE: wte is sized with padded_vocab_size (Vp * C) in fill_in_parameter_sizes
float* wte; // (V, C)
float* wpe; // (maxT, C)
float* ln1w; // (L, C)
float* ln1b; // (L, C)
float* qkvw; // (L, 3*C, C)
float* qkvb; // (L, 3*C)
float* attprojw; // (L, C, C)
float* attprojb; // (L, C)
float* ln2w; // (L, C)
float* ln2b; // (L, C)
float* fcw; // (L, 4*C, C)
float* fcb; // (L, 4*C)
float* fcprojw; // (L, C, 4*C)
float* fcprojb; // (L, C)
float* lnfw; // (C)
float* lnfb; // (C)
} ParameterTensors;
// Fills param_sizes[0 .. NUM_PARAMETER_TENSORS-1] with the element count
// (number of floats, not bytes) of each parameter tensor, in the same order
// as the fields of ParameterTensors.
//   param_sizes : output array with room for at least 16 entries
//   config      : model hyperparameters; reads padded_vocab_size, channels,
//                 max_seq_len and num_layers
void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) {
    // Widen to size_t before multiplying: products such as L * (4 * C) * C
    // can exceed INT_MAX for large configurations, and signed int overflow is
    // undefined behavior. This also matches fill_in_activation_sizes.
    size_t Vp = config.padded_vocab_size;
    size_t C = config.channels;
    size_t maxT = config.max_seq_len;
    size_t L = config.num_layers;
    param_sizes[0] = Vp * C; // wte
    param_sizes[1] = maxT * C; // wpe
    param_sizes[2] = L * C; // ln1w
    param_sizes[3] = L * C; // ln1b
    param_sizes[4] = L * (3 * C) * C; // qkvw
    param_sizes[5] = L * (3 * C); // qkvb
    param_sizes[6] = L * C * C; // attprojw
    param_sizes[7] = L * C; // attprojb
    param_sizes[8] = L * C; // ln2w
    param_sizes[9] = L * C; // ln2b
    param_sizes[10] = L * (4 * C) * C; // fcw
    param_sizes[11] = L * (4 * C); // fcb
    param_sizes[12] = L * C * (4 * C); // fcprojw
    param_sizes[13] = L * C; // fcprojb
    param_sizes[14] = C; // lnfw
    param_sizes[15] = C; // lnfb
}
// allocate memory for the parameters and point the individual tensors to the right places
//   params      : struct whose 16 tensor pointers are set by this function
//   param_sizes : element count of each tensor (as produced by fill_in_parameter_sizes)
//   on_device   : 0 = allocate on the CPU (mallocCheck), non-zero = allocate on the GPU (cudaMalloc)
// Returns the base pointer of the single contiguous allocation that backs all
// tensors. The memory is left uninitialized; no matching free is performed here.
float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes, int on_device) {
    // on_device: 0 = CPU, 1 = GPU
    // calculate the number of parameters
    size_t num_parameters = 0;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters += param_sizes[i];
    }
    // malloc all parameters all at once, on the device or on the host
    float* params_memory;
    if (on_device) {
        cudaCheck(cudaMalloc((void**)&params_memory, num_parameters * sizeof(float)));
    } else {
        params_memory = (float*)mallocCheck(num_parameters * sizeof(float));
    }
    // assign all the tensors their place in the array;
    // the order here must match the order of param_sizes
    float** ptrs[] = {
        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
    };
    // walk through the allocation, handing each tensor its slice in turn
    float* params_memory_iterator = params_memory;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        *(ptrs[i]) = params_memory_iterator;
        params_memory_iterator += param_sizes[i];
    }
    return params_memory;
}
// Pointers to the activation buffers of the model.
// NUM_ACTIVATION_TENSORS must equal the number of fields below, and
// fill_in_activation_sizes fills its act_sizes[] entries in this same field order.
// Shape legend: B = batch, T = sequence length, L = layers, C = channels,
// NH = number of heads.
#define NUM_ACTIVATION_TENSORS 21
typedef struct {
float* encoded; // (B, T, C)
float* ln1; // (L, B, T, C)
float* ln1_mean; // (L, B, T)
float* ln1_rstd; // (L, B, T)
float* atty; // (L, B, T, C)
float* att; // (L, B, NH, T, T)
float* attproj; // (L, B, T, C)
float* residual2; // (L, B, T, C)
float* ln2; // (L, B, T, C)
float* ln2_mean; // (L, B, T)
float* ln2_rstd; // (L, B, T)
float* fch; // (L, B, T, 4*C)
float* fch_gelu; // (L, B, T, 4*C)
float* fcproj; // (L, B, T, C)
float* residual3; // (L, B, T, C)
float* lnf; // (B, T, C)
float* lnf_mean; // (B, T)
float* lnf_rstd; // (B, T)
float* losses; // (B, T)
// adding these two (qkvr and output) compared to the CPU .c code, needed for attention kernel as buffers
float* qkvr; // (L, B, T, 3*C)
// in inference mode, this buffer will store the logits
// in training mode, this buffer will contain the *gradients* of the logits.
// during the processing of transformer blocks, we will also use this as a
// general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),
// (B, NH, T, T), and (B, T, V) shaped tensors.
float* output;
} ActivationTensors;
void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) {
size_t Vp = config.padded_vocab_size;
size_t L = config.num_layers;
size_t NH = config.num_heads;
size_t C = config.channels;
act_sizes[0] = B * T * C; // encoded
act_sizes[1] = L * B * T * C; // ln1
act_sizes[2] = L * B * T; // ln1_mean
act_sizes[3] = L * B * T; // ln1_rstd
act_sizes[4] = L * B * T * C; // atty
act_sizes[5] = L * B * NH * T * T; // att
act_sizes[6] = L * B * T * C; // attproj
act_sizes[7] = L * B * T * C; // residual2