ae2f_docs
Mhattn.auto.h
1#undef __ae2f_MACRO_GENERATED
2#define __ae2f_MACRO_GENERATED 1
4#define ae2f_Ann_Mhattn_auto_h
5
6#include <ae2f/Ann/Mhattn.h>
7#undef __ae2f_MACRO_GENERATED
8#define __ae2f_MACRO_GENERATED 1
9
10#if !ae2f_MAC_BUILD || !__ae2f_MACRO_GENERATED
11#include <assert.h>
12#undef __ae2f_MACRO_GENERATED
13#define __ae2f_MACRO_GENERATED 1
14#include <stdlib.h>
15#undef __ae2f_MACRO_GENERATED
16#define __ae2f_MACRO_GENERATED 1
17#include <math.h>
18#undef __ae2f_MACRO_GENERATED
19#define __ae2f_MACRO_GENERATED 1
20#endif
21
22#if !__ae2f_MACRO_GENERATED
23#include <ae2f/Macro.h>
24#undef __ae2f_MACRO_GENERATED
25#define __ae2f_MACRO_GENERATED 1
26#endif
27
/**
 * @brief Multi-head attention forward pass, expanded in place as a block statement.
 *
 * Stages, in order (kdist = ae2f_AnnMhattnKDist(prm_mhattn)):
 * 1. Projections: every element of ref_{q,k,v}cache is the sum over m_mdldist terms of
 *    ae2f_AnnMhattnFwdSeqConvOne_imp(prm_{qry,key,val}, prm_w{qry,key,val}, ...)
 *    (presumably input row * weight column — see that macro).
 * 2. Per head h: ref_attnw_cache[h][s][t] = dot(q[h][s], k[h][t]) / sqrt(kdist)
 *    (+ prm_mask_opt[s][t] when given); then every element of ret_attnw[h] is produced by
 *    prm_act from the matching score row.
 * 3. Per head h: ret_attno[h][s][d] = sum_t ret_attnw[h][s][t] * v[h][t][d].
 * 4. ret_out[s][c] = sum_j concat(ret_attno)[s][j] * prm_wout[j][c].
 *
 * The head split/concat index helpers assume m_headc * kdist == m_mdldist; this is
 * asserted on entry.
 *
 * @param ref_mem         ae2f_AnnMhattnFwd_t loop/accumulator state; its m_i, m_j, m_k and
 *                        m_U0 fields are overwritten.
 * @param ref_qcache      [out] projected queries (prm_seqlen, m_mdldist)
 * @param ref_kcache      [out] projected keys    (prm_seqlen, m_mdldist)
 * @param ref_vcache      [out] projected values  (prm_seqlen, m_mdldist)
 * @param ref_attnw_cache [out] scaled (and masked) scores before prm_act
 *                        (m_headc, prm_seqlen, prm_seqlen)
 * @param prm_mhattn      layer description; m_headc and m_mdldist are read
 * @param prm_wqry        (m_mdldist, m_mdldist)
 * @param prm_wkey        (m_mdldist, m_mdldist)
 * @param prm_wval        (m_mdldist, m_mdldist)
 * @param prm_wout        (m_mdldist, m_mdldist), indexed [concat column][output column]
 * @param prm_qry         (prm_seqlen, m_mdldist)
 * @param prm_key         (prm_seqlen, m_mdldist)
 * @param prm_val         (prm_seqlen, m_mdldist)
 * @param prm_mask_opt    [opt] additive score mask (prm_seqlen, prm_seqlen) shared by every
 *                        head; may be NULL
 * @param prm_seqlen      sequence length
 * @param prm_act         invoked as prm_act(ret_elem, col, score_row, prm_seqlen) once per
 *                        score element; presumably a row-wise activation such as softmax
 *                        (TODO confirm against ae2f_AnnActFwdMHATTN_t)
 * @param ret_out         [out] (prm_seqlen, m_mdldist)
 * @param ret_attnw       [out] activated attention weights (m_headc, prm_seqlen, prm_seqlen)
 * @param ret_attno       [out] per-head attention output (m_headc, prm_seqlen, kdist)
 */
#define _ae2f_AnnMhattnFwd_imp( \
	/** tparam */ \
 \
	/** param */ \
	/* ae2f_AnnMhattnFwd_t */ ref_mem, \
	/* ae2f_float_t* const */ ref_qcache, \
	/* ae2f_float_t* const */ ref_kcache, \
	/* ae2f_float_t* const */ ref_vcache, \
	/* ae2f_float_t* const */ ref_attnw_cache, \
	/* const ae2f_AnnMhattn_t */ prm_mhattn, \
	/* const ae2f_float_t* const */ prm_wqry, \
	/* const ae2f_float_t* const */ prm_wkey, \
	/* const ae2f_float_t* const */ prm_wval, \
	/* const ae2f_float_t* const */ prm_wout, \
	/* const ae2f_float_t* const */ prm_qry, \
	/* const ae2f_float_t* const */ prm_key, \
	/* const ae2f_float_t* const */ prm_val, \
	/* ae2f_opt const ae2f_float_t* const */ prm_mask_opt, \
	/* const size_t */ prm_seqlen, \
	/* ae2f_AnnActFwdMHATTN_t */ prm_act, \
	/* ae2f_float_t* const */ ret_out, \
	/* ae2f_float_t* const */ ret_attnw, \
	/* ae2f_float_t* const */ ret_attno \
) \
{ \
	{ \
		assert(ref_kcache && "ae2f_AnnMhattnFwd_imp"); \
		assert(ref_qcache && "ae2f_AnnMhattnFwd_imp"); \
		assert(ref_vcache && "ae2f_AnnMhattnFwd_imp"); \
		assert(ref_attnw_cache && "ae2f_AnnMhattnFwd_imp"); \
 \
		assert(prm_wqry && "ae2f_AnnMhattnFwd_imp"); \
		assert(prm_wkey && "ae2f_AnnMhattnFwd_imp"); \
		assert(prm_wval && "ae2f_AnnMhattnFwd_imp"); \
		assert(prm_wout && "ae2f_AnnMhattnFwd_imp"); \
 \
		assert(prm_qry && "ae2f_AnnMhattnFwd_imp"); \
		assert(prm_key && "ae2f_AnnMhattnFwd_imp"); \
		assert(prm_val && "ae2f_AnnMhattnFwd_imp"); \
 \
		assert(ret_out && "ae2f_AnnMhattnFwd_imp"); \
		assert(ret_attnw && "ae2f_AnnMhattnFwd_imp"); \
		assert(ret_attno && "ae2f_AnnMhattnFwd_imp"); \
 \
		/* Head split/concat indexing needs m_headc * kdist == m_mdldist, */ \
		/* so m_headc must be non-zero and divide m_mdldist. */ \
		assert((prm_mhattn).m_headc && "ae2f_AnnMhattnFwd_imp"); \
		assert(!((prm_mhattn).m_mdldist % (prm_mhattn).m_headc) && "ae2f_AnnMhattnFwd_imp"); \
	} \
 \
	/* Stage 1: Q, K, V projections; m_i walks (row, column) of the caches. */ \
	(ref_mem).m_i = (prm_seqlen) * ((prm_mhattn).m_mdldist); \
	while((ref_mem).m_i--) { \
		(ref_mem).m_U0.m_S0.m_k = 0; \
		(ref_mem).m_U0.m_S0.m_q = 0; \
		(ref_mem).m_U0.m_S0.m_v = 0; \
 \
		for((ref_mem).m_j = (prm_mhattn).m_mdldist; (ref_mem).m_j--;) { \
			(ref_mem).m_U0.m_S0.m_q += ae2f_AnnMhattnFwdSeqConvOne_imp( \
				prm_qry \
				, prm_wqry \
				, (prm_mhattn).m_mdldist \
				, prm_seqlen \
				, (ref_mem).m_i / (prm_mhattn).m_mdldist \
				, (ref_mem).m_j \
				, (ref_mem).m_i % (prm_mhattn).m_mdldist); \
 \
			(ref_mem).m_U0.m_S0.m_k += ae2f_AnnMhattnFwdSeqConvOne_imp( \
				prm_key \
				, prm_wkey \
				, (prm_mhattn).m_mdldist \
				, prm_seqlen \
				, (ref_mem).m_i / (prm_mhattn).m_mdldist \
				, (ref_mem).m_j \
				, (ref_mem).m_i % (prm_mhattn).m_mdldist); \
 \
			(ref_mem).m_U0.m_S0.m_v += ae2f_AnnMhattnFwdSeqConvOne_imp( \
				prm_val \
				, prm_wval \
				, (prm_mhattn).m_mdldist \
				, prm_seqlen \
				, (ref_mem).m_i / (prm_mhattn).m_mdldist \
				, (ref_mem).m_j \
				, (ref_mem).m_i % (prm_mhattn).m_mdldist); \
		} \
 \
		(ref_qcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_q; \
		(ref_kcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_k; \
		(ref_vcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_v; \
	} \
 \
	/* Stages 2 and 3, one head (m_i) at a time; the caches are read as */ \
	/* (m_headc, prm_seqlen, kdist) through ae2f_AnnMhattnHeadSplit_imp. */ \
	for((ref_mem).m_i = (prm_mhattn).m_headc; (ref_mem).m_i--;) { \
		/* q * Transpose(k) -> (prm_seqlen, prm_seqlen); m_j walks (s, t). */ \
		(ref_mem).m_j = (prm_seqlen) * (prm_seqlen); \
		while((ref_mem).m_j--) { \
			(ref_mem).m_U0.m_one = 0; \
			(ref_mem).m_k = ae2f_AnnMhattnKDist(prm_mhattn); \
 \
			while((ref_mem).m_k--) { \
				/* Transposed product: both operands are indexed (row, m_k). */ \
				(ref_mem).m_U0.m_one \
					+= (ref_qcache)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn \
						, prm_seqlen \
						, (ref_mem).m_i \
						, (ref_mem).m_j / (prm_seqlen) \
						, (ref_mem).m_k)] \
					* (ref_kcache)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn \
						, prm_seqlen \
						, (ref_mem).m_i \
						, (ref_mem).m_j % (prm_seqlen) \
						, (ref_mem).m_k)]; \
			} \
 \
			/* Scale by 1 / sqrt(kdist), then add the optional mask element. */ \
			(ref_attnw_cache)[ae2f_AnnUtilIdx3( \
				(ref_mem).m_i, (prm_mhattn).m_headc \
				, 0, (prm_seqlen) \
				, 0, (prm_seqlen)) + (ref_mem).m_j] \
				= ((ref_mem).m_U0.m_one / sqrt( \
					ae2f_static_cast(ae2f_float_t, ae2f_AnnMhattnKDist(prm_mhattn)))) \
				+ ((prm_mask_opt) ? (prm_mask_opt)[(ref_mem).m_j] : 0.); \
		} \
 \
		/* Activation: each output element gets the whole score row it belongs to. */ \
		(ref_mem).m_j = (prm_seqlen); \
		while((ref_mem).m_j--) { \
			(ref_mem).m_k = (prm_seqlen); \
			while((ref_mem).m_k--) { \
				prm_act( \
					/* ret */ \
					&(ret_attnw)[ae2f_AnnUtilIdx3( \
						(ref_mem).m_i, (prm_mhattn).m_headc \
						, (ref_mem).m_j, (prm_seqlen) \
						, (ref_mem).m_k, (prm_seqlen))], \
					/* prm_retidx */ \
					(ref_mem).m_k, \
					/* prm_inp */ \
					&(ref_attnw_cache)[ae2f_AnnUtilIdx3( \
						(ref_mem).m_i, (prm_mhattn).m_headc \
						, (ref_mem).m_j, (prm_seqlen) \
						, 0, (prm_seqlen))], \
					/* prm_len */ \
					(prm_seqlen)); \
			} \
		} \
 \
		/* Stage 3: attnw * v -> (prm_seqlen, kdist); m_j walks (s, d). */ \
		(ref_mem).m_j = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn); \
		while((ref_mem).m_j--) { \
			(ref_mem).m_U0.m_one = 0; \
			for((ref_mem).m_k = (prm_seqlen); (ref_mem).m_k--;) { \
				(ref_mem).m_U0.m_one \
					+= (ret_attnw)[ae2f_AnnUtilIdx3( \
						(ref_mem).m_i, (prm_mhattn).m_headc \
						, (ref_mem).m_j / ae2f_AnnMhattnKDist(prm_mhattn), (prm_seqlen) \
						, (ref_mem).m_k, (prm_seqlen))] \
					* (ref_vcache)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn \
						, prm_seqlen \
						, (ref_mem).m_i \
						, (ref_mem).m_k \
						, (ref_mem).m_j % ae2f_AnnMhattnKDist(prm_mhattn))]; \
			} \
 \
			(ret_attno)[ae2f_AnnUtilIdx3( \
				(ref_mem).m_i, (prm_mhattn).m_headc \
				, 0, (prm_seqlen) \
				, 0, ae2f_AnnMhattnKDist(prm_mhattn)) + (ref_mem).m_j] = (ref_mem).m_U0.m_one; \
		} \
	} \
 \
	/* Stage 4: concat(ret_attno) * prm_wout -> ret_out; m_i walks (s, c). */ \
	(ref_mem).m_i = (prm_seqlen) * (prm_mhattn).m_mdldist; \
	while((ref_mem).m_i--) { \
		(ref_mem).m_j = (prm_mhattn).m_mdldist; \
		(ref_mem).m_U0.m_one = 0; \
 \
		while((ref_mem).m_j--) { \
			(ref_mem).m_U0.m_one \
				+= (ret_attno)[ae2f_AnnMhattnHeadConcat_imp( \
					prm_mhattn \
					, prm_seqlen \
					, (ref_mem).m_i / (prm_mhattn).m_mdldist \
					, (ref_mem).m_j)] \
				* (prm_wout)[ae2f_AnnUtilIdx2( \
					(ref_mem).m_j, (prm_mhattn).m_mdldist \
					, (ref_mem).m_i % (prm_mhattn).m_mdldist \
					, (prm_mhattn).m_mdldist)]; \
		} \
 \
		(ret_out)[(ref_mem).m_i] = (ref_mem).m_U0.m_one; \
	} \
}
272
/**
 * @brief Multi-head attention backward pass, expanded in place as a block statement.
 *
 * Stages, in order (kdist = ae2f_AnnMhattnKDist(prm_mhattn)):
 * 1. ret_grad_wout[a][b] = sum_s concat(prm_attno)[s][a] * prm_grad_out[s][b].
 * 2. ref_grad_heads[s][c] = sum_j prm_grad_out[s][j] * prm_wout[c][j], stored head-split
 *    as (m_headc, prm_seqlen, kdist).
 * 3. For every head h:
 *    a. ref_vcache at HeadSplit(h, t, d) = sum_s prm_attnw[h][s][t] * ref_grad_heads[h][s][d]
 *    b. ref_grad_scores[s][t] = sum_d ref_grad_heads[h][s][d] * prm_val at HeadSplit(h, t, d)
 *    c. prm_actderiv rewrites ref_grad_scores element by element, given prm_attnw[h]
 *    d. ref_qcache at HeadSplit(h, s, d) = sum_t ref_grad_scores[s][t] * prm_key at HeadSplit(h, t, d) / sqrt(kdist)
 *       ref_kcache at HeadSplit(h, t, d) = sum_s prm_qry at HeadSplit(h, s, d) * ref_grad_scores[s][t] / sqrt(kdist)
 *       (the forward pass divided the scores by sqrt(kdist), so the chain rule carries it)
 * 4. ret_grad_w{qry,key,val}[a][b] = sum_s prm_{qry,key,val}[s][a] * ref_{q,k,v}cache[s][b].
 *
 * On return, ref_qcache / ref_kcache / ref_vcache hold the stage-3 results, not the forward
 * projections.
 *
 * NOTE(review): stages 3b/3d read the raw inputs prm_val / prm_key / prm_qry rather than the
 * projected V / K / Q. That only equals the true gradient when the projection weights are
 * identity. Confirm the intended contract.
 * NOTE(review): prm_actderiv writes ref_grad_scores in place while the same row is passed as
 * prm_grad_in. A row-coupled derivative (e.g. softmax) would then read already-updated
 * entries. Confirm that the callback tolerates this.
 * NOTE(review): prm_lr, prm_wqry, prm_wkey and prm_wval are accepted but not used.
 *
 * @param ref_qcache      [out] scratch / grad w.r.t. projected Q (prm_seqlen, m_mdldist)
 * @param ref_kcache      [out] scratch / grad w.r.t. projected K (prm_seqlen, m_mdldist)
 * @param ref_vcache      [out] scratch / grad w.r.t. projected V (prm_seqlen, m_mdldist)
 * @param ref_grad_heads  [out] scratch (m_headc, prm_seqlen, kdist)
 * @param ref_grad_scores [out] per-head scratch (prm_seqlen, prm_seqlen)
 * @param prm_mhattn      layer description; m_headc and m_mdldist are read
 * @param prm_grad_out    (prm_seqlen, m_mdldist)
 * @param prm_wqry        (m_mdldist, m_mdldist), unused
 * @param prm_wkey        (m_mdldist, m_mdldist), unused
 * @param prm_wval        (m_mdldist, m_mdldist), unused
 * @param prm_wout        (m_mdldist, m_mdldist)
 * @param prm_qry         (prm_seqlen, m_mdldist)
 * @param prm_key         (prm_seqlen, m_mdldist)
 * @param prm_val         (prm_seqlen, m_mdldist)
 * @param prm_seqlen      sequence length
 * @param prm_lr          unused
 * @param prm_attnw       (m_headc, prm_seqlen, prm_seqlen)
 * @param prm_attno       (m_headc, prm_seqlen, kdist)
 * @param prm_actderiv    invoked as prm_actderiv(ret_elem, col, grad_row, attnw_row, prm_seqlen)
 * @param ret_grad_wqry   [out] (m_mdldist, m_mdldist)
 * @param ret_grad_wkey   [out] (m_mdldist, m_mdldist)
 * @param ret_grad_wval   [out] (m_mdldist, m_mdldist)
 * @param ret_grad_wout   [out] (m_mdldist, m_mdldist)
 */
#define _ae2f_AnnMhattnBwd_imp( \
	/** tparam */ \
 \
	/** param */ \
	/* ae2f_float_t* const */ ref_qcache, \
	/* ae2f_float_t* const */ ref_kcache, \
	/* ae2f_float_t* const */ ref_vcache, \
	/* ae2f_float_t* const */ ref_grad_heads, \
	/* ae2f_float_t* const */ ref_grad_scores, \
	/* const ae2f_AnnMhattn_t */ prm_mhattn, \
	/* const ae2f_float_t* const */ prm_grad_out, \
	/* const ae2f_float_t* const */ prm_wqry, \
	/* const ae2f_float_t* const */ prm_wkey, \
	/* const ae2f_float_t* const */ prm_wval, \
	/* const ae2f_float_t* const */ prm_wout, \
	/* const ae2f_float_t* const */ prm_qry, \
	/* const ae2f_float_t* const */ prm_key, \
	/* const ae2f_float_t* const */ prm_val, \
	/* const size_t */ prm_seqlen, \
	/* const ae2f_float_t */ prm_lr, \
	/* const ae2f_float_t* const */ prm_attnw, \
	/* const ae2f_float_t* const */ prm_attno, \
	/* ae2f_AnnActBwdMHATTN_t */ prm_actderiv, \
	/* ae2f_float_t* const */ ret_grad_wqry, \
	/* ae2f_float_t* const */ ret_grad_wkey, \
	/* ae2f_float_t* const */ ret_grad_wval, \
	/* ae2f_float_t* const */ ret_grad_wout \
) \
{ \
	ae2f_structdef(struct, ae2f_AnnMhattnBwdREG_t) { \
		size_t m_i, m_j, m_hidx; \
 \
		union ae2f_AnnMhattnBwdU0REG_t { \
			struct ae2f_AnnMhattnBwdU0S0REG_t { \
				ae2f_float_t m_q; \
				ae2f_float_t m_k; \
				ae2f_float_t m_v; \
			} m_S0; \
			ae2f_float_t m_f; \
		} m_U0; \
	}; \
 \
	ae2f_reg ae2f_AnnMhattnBwdREG_t reg_bwd; \
 \
	assert(ref_qcache && "ae2f_AnnMhattnBwd_imp"); \
	assert(ref_kcache && "ae2f_AnnMhattnBwd_imp"); \
	assert(ref_vcache && "ae2f_AnnMhattnBwd_imp"); \
	assert(ref_grad_heads && "ae2f_AnnMhattnBwd_imp"); \
	assert(ref_grad_scores && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_grad_out && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_wout && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_qry && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_key && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_val && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_attnw && "ae2f_AnnMhattnBwd_imp"); \
	assert(prm_attno && "ae2f_AnnMhattnBwd_imp"); \
	assert(ret_grad_wqry && "ae2f_AnnMhattnBwd_imp"); \
	assert(ret_grad_wkey && "ae2f_AnnMhattnBwd_imp"); \
	assert(ret_grad_wval && "ae2f_AnnMhattnBwd_imp"); \
	assert(ret_grad_wout && "ae2f_AnnMhattnBwd_imp"); \
	/* Head split/concat indexing needs m_headc * kdist == m_mdldist. */ \
	assert((prm_mhattn).m_headc && "ae2f_AnnMhattnBwd_imp"); \
	assert(!((prm_mhattn).m_mdldist % (prm_mhattn).m_headc) && "ae2f_AnnMhattnBwd_imp"); \
 \
	/* Stage 1: ret_grad_wout; m_i walks (a, b). */ \
	(reg_bwd).m_i = (prm_mhattn).m_mdldist * (prm_mhattn).m_mdldist; \
	while((reg_bwd).m_i--) { \
		(reg_bwd).m_j = (prm_seqlen); \
		(reg_bwd).m_U0.m_f = 0; \
 \
		while((reg_bwd).m_j--) { \
			(reg_bwd).m_U0.m_f \
				+= (prm_attno)[ae2f_AnnMhattnHeadConcat_imp( \
					prm_mhattn, prm_seqlen \
					, (reg_bwd).m_j \
					, (reg_bwd).m_i / (prm_mhattn).m_mdldist)] \
				* (prm_grad_out)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)]; \
		} \
 \
		(ret_grad_wout)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_f; \
	} \
 \
	/* Stage 2: ref_grad_heads = prm_grad_out * Transpose(prm_wout); m_i walks (s, c). */ \
	(reg_bwd).m_i = (prm_seqlen) * (prm_mhattn).m_mdldist; \
	while((reg_bwd).m_i--) { \
		(reg_bwd).m_j = (prm_mhattn).m_mdldist; \
		(reg_bwd).m_U0.m_f = 0; \
		while((reg_bwd).m_j--) { \
			(reg_bwd).m_U0.m_f \
				+= (prm_grad_out)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_i / (prm_mhattn).m_mdldist, (prm_seqlen) \
					, (reg_bwd).m_j, (prm_mhattn).m_mdldist)] \
				* (prm_wout)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist \
					, (reg_bwd).m_j, (prm_mhattn).m_mdldist)]; \
		} \
 \
		(ref_grad_heads)[ae2f_AnnMhattnHeadConcat_imp( \
			prm_mhattn, prm_seqlen \
			, (reg_bwd).m_i / (prm_mhattn).m_mdldist \
			, (reg_bwd).m_i % (prm_mhattn).m_mdldist)] = (reg_bwd).m_U0.m_f; \
	} \
 \
	(reg_bwd).m_hidx = (prm_mhattn).m_headc; \
	while((reg_bwd).m_hidx--) { \
		/* Stage 3a: grad V of head m_hidx; m_i walks (t, d) = (m_i / kdist, m_i % kdist). */ \
		(reg_bwd).m_i = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn); \
		while((reg_bwd).m_i--) { \
			(reg_bwd).m_U0.m_f = 0; \
			(reg_bwd).m_j = (prm_seqlen); \
 \
			while((reg_bwd).m_j--) { \
				/* Transpose(attnw[h]): row m_j (query), column t (key). */ \
				(reg_bwd).m_U0.m_f \
					+= (prm_attnw)[ae2f_AnnUtilIdx3( \
						(reg_bwd).m_hidx, (prm_mhattn).m_headc \
						, (reg_bwd).m_j, (prm_seqlen) \
						, (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn), (prm_seqlen))] \
					* (ref_grad_heads)[ae2f_AnnUtilIdx3( \
						(reg_bwd).m_hidx, (prm_mhattn).m_headc \
						, (reg_bwd).m_j, (prm_seqlen) \
						, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn) \
						, ae2f_AnnMhattnKDist(prm_mhattn))]; \
			} \
 \
			(ref_vcache)[ae2f_AnnMhattnHeadSplit_imp( \
				prm_mhattn, prm_seqlen \
				, (reg_bwd).m_hidx \
				, (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn) \
				, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))] = (reg_bwd).m_U0.m_f; \
		} \
 \
		/* Stage 3b: grad of the activated scores; m_i walks (s, t). */ \
		(reg_bwd).m_i = (prm_seqlen) * (prm_seqlen); \
		while((reg_bwd).m_i--) { \
			(reg_bwd).m_j = ae2f_AnnMhattnKDist(prm_mhattn); \
			(reg_bwd).m_U0.m_f = 0; \
			while((reg_bwd).m_j--) { \
				(reg_bwd).m_U0.m_f \
					+= (ref_grad_heads)[ae2f_AnnUtilIdx3( \
						(reg_bwd).m_hidx, (prm_mhattn).m_headc \
						, (reg_bwd).m_i / (prm_seqlen), (prm_seqlen) \
						, (reg_bwd).m_j, ae2f_AnnMhattnKDist(prm_mhattn))] \
					* (prm_val)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn, prm_seqlen \
						, (reg_bwd).m_hidx \
						, (reg_bwd).m_i % (prm_seqlen) \
						, (reg_bwd).m_j)]; \
			} \
 \
			(ref_grad_scores)[ae2f_AnnUtilIdx2( \
				(reg_bwd).m_i / (prm_seqlen), (prm_seqlen) \
				, (reg_bwd).m_i % (prm_seqlen), (prm_seqlen))] = (reg_bwd).m_U0.m_f; \
		} \
 \
		/* Stage 3c: activation derivative, one element at a time (in place). */ \
		(reg_bwd).m_i = (prm_seqlen); \
		while((reg_bwd).m_i--) { \
			(reg_bwd).m_j = (prm_seqlen); \
			while((reg_bwd).m_j--) { \
				prm_actderiv( \
					/* ret */ \
					&(ref_grad_scores)[ae2f_AnnUtilIdx2( \
						(reg_bwd).m_i, (prm_seqlen) \
						, (reg_bwd).m_j, (prm_seqlen))], \
					/* prm_retidx */ \
					(reg_bwd).m_j, \
					/* prm_grad_in */ \
					&(ref_grad_scores)[ae2f_AnnUtilIdx2( \
						(reg_bwd).m_i, (prm_seqlen) \
						, 0, (prm_seqlen))], \
					/* prm_softmax_out */ \
					&(prm_attnw)[ae2f_AnnUtilIdx3( \
						(reg_bwd).m_hidx, (prm_mhattn).m_headc \
						, (reg_bwd).m_i, (prm_seqlen) \
						, 0, (prm_seqlen))], \
					/* prm_len */ \
					(prm_seqlen)); \
			} \
		} \
 \
		/* Stage 3d: grad Q and grad K; m_i walks (row, d) = (m_i / kdist, m_i % kdist) for both. */ \
		(reg_bwd).m_i = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn); \
		while((reg_bwd).m_i--) { \
			(reg_bwd).m_j = (prm_seqlen); \
			(reg_bwd).m_U0.m_S0.m_k = 0; \
			(reg_bwd).m_U0.m_S0.m_q = 0; \
 \
			while((reg_bwd).m_j--) { \
				/* dQ[row][d] += dS[row][j] * K[j][d] */ \
				(reg_bwd).m_U0.m_S0.m_q \
					+= (ref_grad_scores)[ae2f_AnnUtilIdx2( \
						(reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn), (prm_seqlen) \
						, (reg_bwd).m_j, (prm_seqlen))] \
					* (prm_key)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn, prm_seqlen \
						, (reg_bwd).m_hidx \
						, (reg_bwd).m_j \
						, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))]; \
 \
				/* dK[row][d] += Q[j][d] * dS[j][row] */ \
				(reg_bwd).m_U0.m_S0.m_k \
					+= (prm_qry)[ae2f_AnnMhattnHeadSplit_imp( \
						prm_mhattn, prm_seqlen \
						, (reg_bwd).m_hidx \
						, (reg_bwd).m_j \
						, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))] \
					* (ref_grad_scores)[ae2f_AnnUtilIdx2( \
						(reg_bwd).m_j, (prm_seqlen) \
						, (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn), (prm_seqlen))]; \
			} \
 \
			/* The forward scores were divided by sqrt(kdist). */ \
			(reg_bwd).m_U0.m_S0.m_q = (reg_bwd).m_U0.m_S0.m_q / sqrt( \
				ae2f_static_cast(ae2f_float_t, ae2f_AnnMhattnKDist(prm_mhattn))); \
			(reg_bwd).m_U0.m_S0.m_k = (reg_bwd).m_U0.m_S0.m_k / sqrt( \
				ae2f_static_cast(ae2f_float_t, ae2f_AnnMhattnKDist(prm_mhattn))); \
 \
			(ref_kcache)[ae2f_AnnMhattnHeadSplit_imp( \
				prm_mhattn, prm_seqlen \
				, (reg_bwd).m_hidx \
				, (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn) \
				, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))] = (reg_bwd).m_U0.m_S0.m_k; \
 \
			(ref_qcache)[ae2f_AnnMhattnHeadSplit_imp( \
				prm_mhattn, prm_seqlen \
				, (reg_bwd).m_hidx \
				, (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn) \
				, (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))] = (reg_bwd).m_U0.m_S0.m_q; \
		} /* CONV (GRAD_Q_HEAD, GRAD_K_HEAD) */ \
	} /* HEADCOUNT */ \
 \
	/* Stage 4: weight gradients, X^T * dX_proj; m_i walks (a, b), m_j walks rows s. */ \
	(reg_bwd).m_i = (prm_mhattn).m_mdldist * (prm_mhattn).m_mdldist; \
	while((reg_bwd).m_i--) { \
		(reg_bwd).m_U0.m_S0.m_k = 0; \
		(reg_bwd).m_U0.m_S0.m_q = 0; \
		(reg_bwd).m_U0.m_S0.m_v = 0; \
 \
		(reg_bwd).m_j = (prm_seqlen); \
		while((reg_bwd).m_j--) { \
			(reg_bwd).m_U0.m_S0.m_q \
				+= (prm_qry)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i / (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)] \
				* (ref_qcache)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)]; \
 \
			(reg_bwd).m_U0.m_S0.m_k \
				+= (prm_key)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i / (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)] \
				* (ref_kcache)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)]; \
 \
			(reg_bwd).m_U0.m_S0.m_v \
				+= (prm_val)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i / (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)] \
				* (ref_vcache)[ae2f_AnnUtilIdx2( \
					(reg_bwd).m_j, (prm_seqlen) \
					, (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist)]; \
		} \
 \
		(ret_grad_wqry)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_q; \
		(ret_grad_wkey)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_k; \
		(ret_grad_wval)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_v; \
	} \
}
578
579#endif
580
581#undef __ae2f_MACRO_GENERATED
582
583#define __ae2f_MACRO_GENERATED 0
#define ae2f_AnnUtilIdx2(idx1, sz1, idx0, sz0)
Definition Util.h:23
#define ae2f_AnnUtilIdx3(idx2, sz2, idx1, sz1, idx0, sz0)
Definition Util.h:24
#define ae2f_structdef(key, name)
Definition Cast.h:110
#define ae2f_static_cast(t, v)
Definition Cast.h:42
#define ae2f_reg
Register keyword.
Definition Reg.h:12
#define ae2f_LP(...)
Definition Guide.h:23
#define ae2f_opt
Definition Guide.h:26
#define __ae2f_MACRO_GENERATED
Definition Conv.auto.h:2
#define ae2f_MAC_BUILD
Definition Util.h:5
#define ae2f_Ann_Mhattn_auto_h
Definition Mhattn.auto.h:4
#define ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0)
Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].
Definition Mhattn.h:67
#define ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0)
Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].
Definition Mhattn.h:46
#define ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k)
Definition Mhattn.h:95
#define ae2f_AnnMhattnKDist(prm_mhattn)
m_headc * kdist == m_mdldist
Definition Mhattn.h:32
#define ae2f_NEED_CLASS
Definition Mlp.cl.c:8
#define ae2f_MAC(...)
Definition mac.h:28