2#define ae2f_NEED_CLASS 1
6
7
8
9
10#ifndef ae2f_Ann_Mhattn_h
11#define ae2f_Ann_Mhattn_h
14#include <ae2f/Guide.h>
17#include "./Mhattn.auto.h"
18#include "./Mhattn.core.h"
22#include <ae2f/Pack/Beg.h>
27 size_t m_mdldist, m_headc;
/**
 * @brief Per-head dimension (kdist) of a multi-head attention descriptor.
 *
 * Expands to `m_mdldist / m_headc`. Both members are `size_t`, so the
 * result is truncated when m_mdldist is not a multiple of m_headc. Callers
 * are expected to keep `m_headc * kdist == m_mdldist`.
 *
 * @param prm_mhattn Expression with members `m_mdldist` and `m_headc`.
 *                   It is evaluated twice, so avoid side effects.
 *                   `m_headc` must be non-zero; otherwise the division
 *                   is undefined behaviour.
 */
/* The backslash is required: without it the macro expands to nothing
 * and the body becomes a stray file-scope expression. */
#define ae2f_AnnMhattnKDist(prm_mhattn) \
	((prm_mhattn).m_mdldist / (prm_mhattn).m_headc)
36
37
38
39
40
41
42
43
44
45
46#define ae2f_AnnMhattnHeadSplit_imp(
55 , (prm_mhattn).m_mdldist)
58
59
60
61
62
63
64
65
66
67#define ae2f_AnnMhattnHeadConcat_imp(
74 , (prm_mhattn).m_headc
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95#define ae2f_AnnMhattnFwdSeqConvOne_imp(
111
112
113
117 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wqry;
118 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wkey;
119 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wval;
120 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wout;
128 ae2f_AnnLossFFN_t* m_loss;
131#include <ae2f/Pack/End.h>
#define ae2f_AnnUtilIdx2(idx1, sz1, idx0, sz0)
#define ae2f_AnnUtilIdx3(idx2, sz2, idx1, sz1, idx0, sz0)
#define ae2f_structdef(key, name)
#define ae2f_static_cast(t, v)
#define ae2f_reg
Register keyword.
#define __ae2f_MACRO_GENERATED
#define ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0)
Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].
#define ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0)
Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].
#define ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k)
#define ae2f_AnnMhattnKDist(prm_mhattn)
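Per-head dimension kdist = m_mdldist / m_headc; the invariant m_headc * kdist == m_mdldist holds only when m_mdldist is a multiple of m_headc.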