1#undef __ae2f_MACRO_GENERATED
2#define __ae2f_MACRO_GENERATED 1
4#define ae2f_Ann_Mhattn_auto_h
6#include <ae2f/Ann/Mhattn.h>
7#undef __ae2f_MACRO_GENERATED
8#define __ae2f_MACRO_GENERATED 1
10#if !ae2f_MAC_BUILD || !__ae2f_MACRO_GENERATED
12#undef __ae2f_MACRO_GENERATED
13#define __ae2f_MACRO_GENERATED 1
15#undef __ae2f_MACRO_GENERATED
16#define __ae2f_MACRO_GENERATED 1
18#undef __ae2f_MACRO_GENERATED
19#define __ae2f_MACRO_GENERATED 1
22#if !__ae2f_MACRO_GENERATED
23#include <ae2f/Macro.h>
24#undef __ae2f_MACRO_GENERATED
25#define __ae2f_MACRO_GENERATED 1
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/* _ae2f_AnnMhattnFwd_imp -- multi-head attention forward pass, written as a
 * function-like macro. NOTE(review): this listing omits the parameter list and
 * many body lines; the comments below describe only what the visible lines do.
 *
 * Visible stages:
 *   1. Q/K/V projection: for every element m_i of a prm_seqlen * m_mdldist
 *      buffer, sum ae2f_AnnMhattnFwdSeqConvOne_imp over m_mdldist terms and
 *      store the results in ref_qcache, ref_kcache and ref_vcache.
 *   2. For each head: raw score = dot product of q and k over
 *      ae2f_AnnMhattnKDist(prm_mhattn) features, divided by sqrt() of that
 *      per-head width, plus prm_mask_opt[j] when a mask is supplied; stored
 *      in ref_attnw_cache.
 *   3. A row-by-row pass reading ref_attnw_cache and writing ret_attnw (the
 *      operation is not visible here; presumably a softmax -- TODO confirm).
 *   4. ret_attnw-weighted sum of ref_vcache for every (position, head
 *      feature) pair, written to a per-head buffer whose name is not visible.
 *   5. Output projection: head outputs (read through
 *      ae2f_AnnMhattnHeadConcat_imp) times prm_wout, written to ret_out.
 *
 * Every pointer argument is asserted non-null except prm_mask_opt, which may
 * be null. Loop counters (m_i, m_j, m_k) and accumulators (m_U0) live in
 * ref_mem.
 */
48#define _ae2f_AnnMhattnFwd_imp(
72)\
73{
75 assert(ref_kcache && "ae2f_AnnMhattnFwd_imp");
76 assert(ref_qcache && "ae2f_AnnMhattnFwd_imp");
77 assert(ref_vcache && "ae2f_AnnMhattnFwd_imp");
78 assert(ref_attnw_cache && "ae2f_AnnMhattnFwd_imp");
80 assert(prm_wqry && "ae2f_AnnMhattnFwd_imp");
81 assert(prm_wkey && "ae2f_AnnMhattnFwd_imp");
82 assert(prm_wval && "ae2f_AnnMhattnFwd_imp");
83 assert(prm_wout && "ae2f_AnnMhattnFwd_imp");
85 assert(prm_qry && "ae2f_AnnMhattnFwd_imp");
86 assert(prm_key && "ae2f_AnnMhattnFwd_imp");
87 assert(prm_val && "ae2f_AnnMhattnFwd_imp");
89 assert(ret_out && "ae2f_AnnMhattnFwd_imp");
90 assert(ret_attnw && "ae2f_AnnMhattnFwd_imp");
91 assert(ret_attno && "ae2f_AnnMhattnFwd_imp");
/* Stage 1: Q/K/V projection. m_i walks the prm_seqlen * m_mdldist output;
 * m_i / m_mdldist and m_i % m_mdldist are handed to the per-element helper,
 * and the three sums are stored at flat index m_i of each cache. */
94 (ref_mem).m_i = (prm_seqlen) * ((prm_mhattn).m_mdldist) ;
97
98
99 while((ref_mem).m_i--) {
100 (ref_mem).m_U0.m_S0.m_k = 0
;
101 (ref_mem).m_U0.m_S0.m_q = 0
;
102 (ref_mem).m_U0.m_S0.m_v = 0
;
104 for((ref_mem).m_j = (prm_mhattn).m_mdldist ; (ref_mem).m_j--;) {
105 (ref_mem).m_U0.m_S0.m_q
106 += ae2f_AnnMhattnFwdSeqConvOne_imp(
109 , (prm_mhattn).m_mdldist
111 , (ref_mem).m_i / (prm_mhattn).m_mdldist
113 , (ref_mem).m_i % (prm_mhattn).m_mdldist
116 (ref_mem).m_U0.m_S0.m_k
117 += ae2f_AnnMhattnFwdSeqConvOne_imp(
120 , (prm_mhattn).m_mdldist
122 , (ref_mem).m_i / (prm_mhattn).m_mdldist
124 , (ref_mem).m_i % (prm_mhattn).m_mdldist
127 (ref_mem).m_U0.m_S0.m_v
128 += ae2f_AnnMhattnFwdSeqConvOne_imp(
131 , (prm_mhattn).m_mdldist
133 , (ref_mem).m_i / (prm_mhattn).m_mdldist
135 , (ref_mem).m_i % (prm_mhattn).m_mdldist
139 (ref_qcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_q;
140 (ref_kcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_k;
141 (ref_vcache)[(ref_mem).m_i] = (ref_mem).m_U0.m_S0.m_v;
145
146
147
/* Stages 2-4 run once per head; m_i is the head index here. */
148 for((ref_mem).m_i = (prm_mhattn).m_headc; (ref_mem).m_i--;) {
149 (ref_mem).m_j = (prm_seqlen) * (prm_seqlen) ;
151
152
/* Stage 2: m_j walks the prm_seqlen x prm_seqlen score matrix; the q row
 * comes from m_j / prm_seqlen and the k row from m_j % prm_seqlen, and m_k
 * runs over the per-head feature width. */
153 while((ref_mem).m_j--) {
154 (ref_mem).m_U0.m_one = 0
;
155 (ref_mem).m_k = ae2f_AnnMhattnKDist(prm_mhattn) ;
157 while((ref_mem).m_k--) {
160 += (ref_qcache)[ae2f_AnnMhattnHeadSplit_imp(
164 , (ref_mem).m_j / (prm_seqlen)
167 * (ref_kcache)[ae2f_AnnMhattnHeadSplit_imp(
171 , (ref_mem).m_j % (prm_seqlen)
176 (ref_attnw_cache)[ae2f_AnnUtilIdx3(
178 , (prm_mhattn).m_headc
184 ((ref_mem).m_U0.m_one / sqrt(
187 , ae2f_AnnMhattnKDist(prm_mhattn)))
/* The mask is optional and shared by all heads: a null prm_mask_opt adds 0. */
189 + ((prm_mask_opt) ? (prm_mask_opt)[(ref_mem).m_j] : 0.
);
/* Stage 3: row-wise pass from ref_attnw_cache into ret_attnw; the call that
 * performs it is not visible in this listing (presumably softmax -- confirm). */
192 (ref_mem).m_j = (prm_seqlen);
193 while((ref_mem).m_j--) {
194 (ref_mem).m_k = (prm_seqlen);
195 while((ref_mem).m_k--) {
198 &(ret_attnw)[ae2f_AnnUtilIdx3(
200 , (prm_mhattn).m_headc
201 , (ref_mem).m_j, (prm_seqlen)
202 , (ref_mem).m_k, (prm_seqlen)
207 &(ref_attnw_cache)[ae2f_AnnUtilIdx3(
208 (ref_mem).m_i, (prm_mhattn).m_headc
209 , (ref_mem).m_j, (prm_seqlen)
/* Stage 4: m_j / kdist is the query position and m_j % kdist the head
 * feature; m_k sums ret_attnw weights times ref_vcache over key positions. */
218 (ref_mem).m_j = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn) ;
219 while((ref_mem).m_j--)
221 (ref_mem).m_U0.m_one = 0
;
222 for((ref_mem).m_k = (prm_seqlen) ; (ref_mem).m_k--;)
225 += (ret_attnw)[ae2f_AnnUtilIdx3(
226 (ref_mem).m_i, (prm_mhattn).m_headc
227 , (ref_mem).m_j / ae2f_AnnMhattnKDist(prm_mhattn), prm_seqlen
228 , (ref_mem).m_k, prm_seqlen
230 (ref_vcache)[ae2f_AnnMhattnHeadSplit_imp(
235 , (ref_mem).m_j % ae2f_AnnMhattnKDist(prm_mhattn)
241 (ref_mem).m_i, (prm_mhattn).m_headc
243 , 0
, ae2f_AnnMhattnKDist(prm_mhattn)
245 ] = (ref_mem).m_U0.m_one;
/* Stage 5: output projection. For each (position, output feature) element
 * m_i, sum head outputs times prm_wout[m_j][m_i % m_mdldist] over m_j. */
249 (ref_mem).m_i = (prm_seqlen) * (prm_mhattn).m_mdldist ;
250 while((ref_mem).m_i--) {
251 (ref_mem).m_j = ((prm_mhattn).m_mdldist) ;
252 (ref_mem).m_U0.m_one = 0
;
254 while((ref_mem).m_j--) {
255 (ref_mem).m_U0.m_one +=
257 ae2f_AnnMhattnHeadConcat_imp(
260 , (ref_mem).m_i / (prm_mhattn).m_mdldist
263 (prm_wout)[ae2f_AnnUtilIdx2(
264 (ref_mem).m_j, (prm_mhattn).m_mdldist
265 , (ref_mem).m_i % (prm_mhattn).m_mdldist
266 , (prm_mhattn).m_mdldist)];
269 (ret_out)[(ref_mem).m_i] = (ref_mem).m_U0.m_one;
271}
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
/* _ae2f_AnnMhattnBwd_imp -- multi-head attention backward pass, written as a
 * function-like macro. NOTE(review): this listing omits the parameter list and
 * many body lines; the comments below describe only what the visible lines do.
 *
 * Working state: reg_bwd holds the size_t counters m_i, m_j, m_hidx and the
 * accumulator union m_U0 (m_f, or the m_S0 fields m_q / m_k / m_v).
 * ref_bwd (a union holding a one-element m_fa array) is declared but not used
 * in the visible lines.
 *
 * Visible stages:
 *   1. ret_grad_wout: prm_attno times prm_grad_out, summed over prm_seqlen.
 *   2. prm_grad_out times prm_wout, summed over m_mdldist, stored at a
 *      head-concatenated index (destination array not visible).
 *   3. Per head m_hidx: a prm_attnw x ref_grad_heads product; ref_grad_scores
 *      from ref_grad_heads and prm_val; a row-wise pass over ref_grad_scores
 *      and prm_attnw (call not visible; presumably softmax backward -- TODO
 *      confirm); query/key accumulations stored in ref_qcache / ref_kcache.
 *   4. ret_grad_wqry / ret_grad_wkey / ret_grad_wval from prm_qry / prm_key /
 *      prm_val and the q/k/v caches.
 */
299#define _ae2f_AnnMhattnBwd_imp(
327)\
328{
329 ae2f_structdef(struct, ae2f_AnnMhattnBwdREG_t) {
330 size_t m_i, m_j, m_hidx;
332 union ae2f_AnnMhattnBwdU0REG_t {
333 struct ae2f_AnnMhattnBwdU0S0REG_t {
342 ae2f_structdef(struct, ae2f_AnnMhattnBwdRAM_t) {
343 union ae2f_AnnMhattnBwdU0_t {
345 ae2f_float_t m_fa[1
];
349 ae2f_reg ae2f_AnnMhattnBwdREG_t reg_bwd;
350 ae2f_reg ae2f_AnnMhattnBwdRAM_t ref_bwd;
/* Stage 1: ret_grad_wout[m_i] for m_i < m_mdldist * m_mdldist; row feature
 * m_i / m_mdldist (via HeadConcat), column m_i % m_mdldist, summed over the
 * sequence positions m_j. */
353 (reg_bwd).m_i = (prm_mhattn).m_mdldist * (prm_mhattn).m_mdldist ;
354 while((reg_bwd).m_i--) {
355 (reg_bwd).m_j = (prm_seqlen) ;
356 (reg_bwd).m_U0.m_f = 0
;
358 while((reg_bwd).m_j--) {
360 += (prm_attno)[ae2f_AnnMhattnHeadConcat_imp(
361 prm_mhattn, prm_seqlen
363 , (reg_bwd).m_i / (prm_mhattn).m_mdldist) ]
364 * (prm_grad_out)[ae2f_AnnUtilIdx2(
365 (reg_bwd).m_j, (prm_seqlen)
366 , (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist
370 (ret_grad_wout)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_f;
/* Stage 2: for each (position m_i / m_mdldist, feature m_i % m_mdldist),
 * sum prm_grad_out[pos][m_j] * prm_wout[feature][m_j] over m_j. */
373 (reg_bwd).m_i = (prm_seqlen) * (prm_mhattn).m_mdldist ;
374 while((reg_bwd).m_i--) {
375 (reg_bwd).m_j = (prm_mhattn).m_mdldist ;
376 (reg_bwd).m_U0.m_f = 0
;
377 while((reg_bwd).m_j--) {
380 += (prm_grad_out)[ae2f_AnnUtilIdx2(
381 (reg_bwd).m_i / (prm_mhattn).m_mdldist, prm_seqlen
382 , (reg_bwd).m_j, (prm_mhattn).m_mdldist
384 * (prm_wout)[ae2f_AnnUtilIdx2(
385 (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist
386 , (reg_bwd).m_j, (prm_mhattn).m_mdldist
392 ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen
393 , (reg_bwd).m_i / (prm_mhattn).m_mdldist
394 , (reg_bwd).m_i % (prm_mhattn).m_mdldist
396 ] = (reg_bwd).m_U0.m_f;
/* Stage 3: everything below (up to the weight-gradient stage) is per head. */
399 (reg_bwd).m_hidx = (prm_mhattn).m_headc;
400 while((reg_bwd).m_hidx--) {
401 (reg_bwd).m_i = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn) ;
/* NOTE(review): this loop tests m_j, but the line above initialises m_i.
 * m_j is reset to prm_seqlen just below and the inner while(m_j--) leaves it
 * at SIZE_MAX (size_t wrap-around), so unless a line not shown here changes
 * it, this condition never becomes false and m_i is never decremented. The
 * sibling loops in this macro all test the counter they just initialised;
 * likely intended: while((reg_bwd).m_i--). */
402 while((reg_bwd).m_j--) {
403 (reg_bwd).m_U0.m_f = 0
;
404 (reg_bwd).m_j = (prm_seqlen) ;
406 while((reg_bwd).m_j--) {
/* NOTE(review): prm_attnw is indexed here with ae2f_AnnUtilIdx2 (no head
 * index), whereas the row-wise pass below indexes it with
 * ae2f_AnnUtilIdx3(m_hidx, m_headc, ...). If prm_attnw stores one matrix per
 * head, this presumably reads head 0's weights for every head -- confirm. */
408 += (prm_attnw)[ae2f_AnnUtilIdx2(
409 (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn), (prm_seqlen)
410 , (reg_bwd).m_j, (prm_seqlen))
412 (ref_grad_heads)[ae2f_AnnUtilIdx3(
413 (reg_bwd).m_hidx, (prm_mhattn).m_headc
416 , (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn)
417 , ae2f_AnnMhattnKDist(prm_mhattn)
423 ae2f_AnnMhattnHeadSplit_imp(
424 prm_mhattn, prm_seqlen
426 , (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn)
427 , (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn)
429 ] = (reg_bwd).m_U0.m_f;
/* ref_grad_scores[row][col] (row = m_i / prm_seqlen, col = m_i % prm_seqlen):
 * dot product of ref_grad_heads and prm_val over the per-head feature m_j. */
432 (reg_bwd).m_i = (prm_seqlen) * (prm_seqlen) ;
433 while((reg_bwd).m_i--) {
434 (reg_bwd).m_j = ae2f_AnnMhattnKDist(prm_mhattn);
435 (reg_bwd).m_U0.m_f = 0
;
436 while((reg_bwd).m_j--) {
437 (reg_bwd).m_U0.m_f +=
438 (ref_grad_heads)[ae2f_AnnUtilIdx3(
439 (reg_bwd).m_hidx, (prm_mhattn).m_headc
440 , (reg_bwd).m_i / (prm_seqlen), prm_seqlen
441 , (reg_bwd).m_j, ae2f_AnnMhattnKDist(prm_mhattn)
443 (prm_val)[ae2f_AnnMhattnHeadSplit_imp(
444 prm_mhattn, prm_seqlen
446 , (reg_bwd).m_i % (prm_seqlen)
451 (ref_grad_scores)[ae2f_AnnUtilIdx2(
452 (reg_bwd).m_i / (prm_seqlen), prm_seqlen
453 , (reg_bwd).m_i % (prm_seqlen), prm_seqlen)
454 ] = (reg_bwd).m_U0.m_f;
/* Row-wise pass over ref_grad_scores using this head's prm_attnw; the call
 * itself is not visible (presumably softmax backward -- confirm). */
457 (reg_bwd).m_i = (prm_seqlen);
458 while((reg_bwd).m_i--) {
459 (reg_bwd).m_j = (prm_seqlen);
460 while((reg_bwd).m_j--) {
463 &(ref_grad_scores)[ae2f_AnnUtilIdx2(
464 (reg_bwd).m_i, prm_seqlen,
465 (reg_bwd).m_j, prm_seqlen
470 &(ref_grad_scores)[ae2f_AnnUtilIdx2(
471 (reg_bwd).m_i, prm_seqlen,
475 &(prm_attnw)[ae2f_AnnUtilIdx3(
476 (reg_bwd).m_hidx, (prm_mhattn).m_headc,
477 (reg_bwd).m_i, prm_seqlen,
/* Query (m_q) and key (m_k) accumulations over m_i < prm_seqlen * kdist,
 * where kdist = ae2f_AnnMhattnKDist(prm_mhattn); results go to ref_qcache and
 * ref_kcache. */
486 (reg_bwd).m_i = (prm_seqlen) * ae2f_AnnMhattnKDist(prm_mhattn) ;
487 while((reg_bwd).m_i--) {
488 (reg_bwd).m_j = (prm_seqlen);
489 (reg_bwd).m_U0.m_S0.m_k = 0
;
490 (reg_bwd).m_U0.m_S0.m_q = 0
;
492 while((reg_bwd).m_j--) {
493 (reg_bwd).m_U0.m_S0.m_q +=
494 (ref_grad_scores)[ae2f_AnnUtilIdx2(
495 (reg_bwd).m_i / ae2f_AnnMhattnKDist(prm_mhattn)
497 , (reg_bwd).m_j, prm_seqlen)
499 (prm_key)[ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen
502 , (reg_bwd).m_i % ae2f_AnnMhattnKDist(prm_mhattn))];
505 (reg_bwd).m_U0.m_S0.m_k +=
506 (prm_qry)[ae2f_AnnMhattnHeadSplit_imp(
507 prm_mhattn, prm_seqlen
510 , (reg_bwd).m_i / (prm_seqlen)
512 (ref_grad_scores)[ae2f_AnnUtilIdx2(
513 (reg_bwd).m_j, (prm_seqlen)
514 , (reg_bwd).m_i % (prm_seqlen), prm_seqlen
/* NOTE(review): the m_q accumulation above splits m_i as
 * (m_i / kdist, m_i % kdist), but the m_k accumulation and the two stores
 * below split it as (m_i / prm_seqlen, m_i % prm_seqlen). With
 * m_i < prm_seqlen * kdist these agree only when prm_seqlen == kdist;
 * otherwise the last HeadSplit index can exceed kdist and a sum may land in
 * the wrong element -- confirm. */
518 (ref_kcache)[ae2f_AnnMhattnHeadSplit_imp(
519 prm_mhattn, prm_seqlen
521 , (reg_bwd).m_i / (prm_seqlen)
522 , (reg_bwd).m_i % (prm_seqlen)
523 )] = (reg_bwd).m_U0.m_S0.m_k;
525 (ref_qcache)[ae2f_AnnMhattnHeadSplit_imp(
526 prm_mhattn, prm_seqlen
528 , (reg_bwd).m_i / (prm_seqlen)
529 , (reg_bwd).m_i % (prm_seqlen)
530 )] = (reg_bwd).m_U0.m_S0.m_q;
/* Stage 4: weight gradients ret_grad_wqry / ret_grad_wkey / ret_grad_wval,
 * one m_mdldist x m_mdldist element per m_i, summed over positions m_j. */
534 (reg_bwd).m_i = (prm_mhattn).m_mdldist * (prm_mhattn).m_mdldist ;
535 while((reg_bwd).m_i--) {
536 (reg_bwd).m_U0.m_S0.m_k = 0
;
537 (reg_bwd).m_U0.m_S0.m_q = 0
;
538 (reg_bwd).m_U0.m_S0.m_v = 0
;
/* NOTE(review): elsewhere in this macro ae2f_AnnUtilIdx2 takes the sequence
 * index first and the feature index second (e.g. prm_grad_out in stage 1,
 * and the cache reads just below). Here prm_qry / prm_key / prm_val are
 * indexed as Idx2(m_i / m_mdldist, prm_seqlen, m_j, m_mdldist): a feature
 * index (< m_mdldist) in the sequence slot and the position m_j
 * (< prm_seqlen) in the feature slot. This looks transposed -- confirm. */
540 (reg_bwd).m_j = (prm_seqlen);
541 while((reg_bwd).m_j--) {
542 (reg_bwd).m_U0.m_S0.m_q +=
543 (prm_qry)[ae2f_AnnUtilIdx2(
544 (reg_bwd).m_i / (prm_mhattn).m_mdldist, prm_seqlen
545 , (reg_bwd).m_j, (prm_mhattn).m_mdldist
547 (ref_qcache)[ae2f_AnnUtilIdx2(
548 (reg_bwd).m_j, (prm_mhattn).m_mdldist
549 , (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist
552 (reg_bwd).m_U0.m_S0.m_k +=
553 (prm_key)[ae2f_AnnUtilIdx2(
554 (reg_bwd).m_i / (prm_mhattn).m_mdldist, prm_seqlen
555 , (reg_bwd).m_j, (prm_mhattn).m_mdldist
557 (ref_kcache)[ae2f_AnnUtilIdx2(
558 (reg_bwd).m_j, (prm_mhattn).m_mdldist
559 , (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist
562 (reg_bwd).m_U0.m_S0.m_v +=
563 (prm_val)[ae2f_AnnUtilIdx2(
564 (reg_bwd).m_i / (prm_mhattn).m_mdldist, prm_seqlen
565 , (reg_bwd).m_j, (prm_mhattn).m_mdldist
567 (ref_vcache)[ae2f_AnnUtilIdx2(
568 (reg_bwd).m_j, (prm_mhattn).m_mdldist
569 , (reg_bwd).m_i % (prm_mhattn).m_mdldist, (prm_mhattn).m_mdldist
573 (ret_grad_wqry)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_q;
574 (ret_grad_wkey)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_k;
575 (ret_grad_wval)[(reg_bwd).m_i] = (reg_bwd).m_U0.m_S0.m_v;
577}
581#undef __ae2f_MACRO_GENERATED
583#define __ae2f_MACRO_GENERATED 0
#define ae2f_AnnUtilIdx2(idx1, sz1, idx0, sz0)
#define ae2f_AnnUtilIdx3(idx2, sz2, idx1, sz1, idx0, sz0)
#define ae2f_structdef(key, name)
#define ae2f_static_cast(t, v)
#define ae2f_reg
Register keyword.
#define __ae2f_MACRO_GENERATED
#define ae2f_Ann_Mhattn_auto_h
#define ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0)
Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].
#define ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0)
Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].
#define ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k)
#define ae2f_AnnMhattnKDist(prm_mhattn)
Per-head feature width (kdist) of prm_mhattn, chosen so that m_headc * kdist == m_mdldist.