ae2f_docs
Mhattn.h
Go to the documentation of this file.
/*
 * Default ae2f_NEED_CLASS to 1 when the including translation unit has not
 * chosen a value itself; an earlier #define (such a definition exists in
 * Mlp.cl.c) takes precedence.
 * NOTE(review): presumably this switch gates the class/struct definitions
 * later in this header — confirm against the other users of ae2f_NEED_CLASS.
 */
#ifndef ae2f_NEED_CLASS
#define ae2f_NEED_CLASS 1
#endif
4
/**
 * @file Mhattn.h
 * @brief
 * Multi-head attention: metadata, head split/concat index helpers,
 * and the layer structure.
 * */
10#ifndef ae2f_Ann_Mhattn_h
11#define ae2f_Ann_Mhattn_h
12
13#include <ae2f/Cast.h>
14#include <ae2f/Guide.h>
15#include "./Act.h"
16
17#include "./Mhattn.auto.h"
18#include "./Mhattn.core.h"
19
20#include "./Util.h"
21
22#include <ae2f/Pack/Beg.h>
23
24/** Metadata for Multihead Attention */
25ae2f_structdef(struct, ae2f_AnnMhattn_t)
26{
27 size_t m_mdldist, m_headc;
28};
29
30
/**
 * @brief Per-head width (kdist), chosen so that m_headc * kdist == m_mdldist.
 *
 * @details
 * Works on any struct exposing `m_mdldist` and `m_headc`
 * (e.g. ae2f_AnnMhattn_t or ae2f_AnnMhattn).
 * Requires m_headc != 0. Integer division truncates, so the identity above
 * only holds when m_mdldist is a multiple of m_headc.
 * The argument is expanded twice; pass an expression without side effects.
 */
#define ae2f_AnnMhattnKDist(/** const ae2f_AnnMhattn_t */ prm_mhattn) \
	(((prm_mhattn).m_mdldist) / ((prm_mhattn).m_headc))
34
/**
 * @brief
 * Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist]
 *
 * @details
 * Element (m_i2, m_i1, m_i0) of the per-head view lives at row m_i1,
 * column (m_i2 * kdist + m_i0) of the [prm_seqlen, m_mdldist] buffer,
 * where kdist is ae2f_AnnMhattnKDist(prm_mhattn). The flat index is
 * produced by ae2f_AnnUtilIdx2.
 * prm_mhattn is expanded several times; pass an expression without
 * side effects.
 *
 * @returns {const size_t}
 * @param prm_mhattn {const ae2f_AnnMhattn_t&} any struct with m_mdldist and m_headc
 * @param prm_seqlen {const size_t}
 * @param m_i2 {const size_t} < m_headc
 * @param m_i1 {const size_t} < prm_seqlen
 * @param m_i0 {const size_t} < kdist
 * */
#define ae2f_AnnMhattnHeadSplit_imp( \
		prm_mhattn, \
		prm_seqlen, \
		m_i2, \
		m_i1, \
		m_i0 \
		) \
	(ae2f_AnnUtilIdx2( \
			(m_i1), (prm_seqlen) \
			, ((m_i0) + (m_i2) * ae2f_AnnMhattnKDist(prm_mhattn)) \
			, ((prm_mhattn).m_mdldist)))
56
/**
 * @brief
 * Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist]
 *
 * @details
 * Column m_i0 of row m_i1 belongs to head (m_i0 / kdist) at offset
 * (m_i0 % kdist) inside that head, where kdist is
 * ae2f_AnnMhattnKDist(prm_mhattn). The flat index is produced by
 * ae2f_AnnUtilIdx3.
 * prm_mhattn and m_i0 are expanded more than once; pass expressions
 * without side effects.
 *
 * @returns {const size_t}
 * @param prm_mhattn {const ae2f_AnnMhattn_t&} any struct with m_mdldist and m_headc
 * @param prm_seqlen {const size_t}
 * @param m_i1 {const size_t} < prm_seqlen
 * @param m_i0 {const size_t} < m_mdldist
 * */
#define ae2f_AnnMhattnHeadConcat_imp( \
		prm_mhattn, \
		prm_seqlen, \
		m_i1, \
		m_i0 \
		) \
	(ae2f_AnnUtilIdx3( \
			((m_i0) / ae2f_AnnMhattnKDist(prm_mhattn)) \
			, ((prm_mhattn).m_headc) \
			, (m_i1), (prm_seqlen) \
			, ((m_i0) % ae2f_AnnMhattnKDist(prm_mhattn)) \
			, ae2f_AnnMhattnKDist(prm_mhattn) \
			))
79
/**
 * @brief
 * One product term of the sequence-by-weight product:
 * prm_seq[prm_i, prm_j] * prm_w[prm_j, prm_k].
 *
 * @details
 * Assuming ae2f_AnnUtilIdx2(r, R, c, C) addresses element (r, c) of an
 * R x C buffer, summing this term over prm_j in [0, prm_mdldist) yields
 * element (prm_i, prm_k) of prm_seq x prm_w.
 * Arguments are expanded inside index expressions; pass expressions
 * without side effects.
 *
 * Original implementation by \@kenter7317.
 * Macrofied by \@dalmurii.
 *
 * @param prm_seq {const ae2f_float_t* const} (prm_seqlen, prm_mdldist)
 * @param prm_w {const ae2f_float_t* const} (prm_mdldist, prm_mdldist)
 * @param prm_mdldist {const size_t}
 * @param prm_seqlen {const size_t}
 * @param prm_i {const size_t} < prm_seqlen
 * @param prm_j {const size_t} < prm_mdldist
 * @param prm_k {const size_t} < prm_mdldist
 * */
#define ae2f_AnnMhattnFwdSeqConvOne_imp( \
		prm_seq, \
		prm_w, \
		prm_mdldist, \
		prm_seqlen, \
		prm_i, \
		prm_j, \
		prm_k \
		) \
	((prm_seq)[ae2f_AnnUtilIdx2((prm_i), (prm_seqlen), (prm_j), (prm_mdldist))] * \
	 (prm_w)[ae2f_AnnUtilIdx2((prm_j), (prm_mdldist), (prm_k), (prm_mdldist))])
106
107
109
110/**
111 * @brief
112 * [M]ulti-[h]ead [att]e[n]tion.
113 * */
114ae2f_structdef(struct, ae2f_AnnMhattn) {
115
116 /** Weights */
117 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wqry;
118 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wkey;
119 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wval;
120 ae2f_LP(m_mdldist, m_mdldist) ae2f_float_t* ae2f_restrict m_wout;
121
122 /** @brief Model distance */
123 size_t m_mdldist;
124
125 /** @brief Head count */
126 size_t m_headc;
127
128 ae2f_AnnLossFFN_t* m_loss;
129};
130
131#include <ae2f/Pack/End.h>
132
133#endif
134
135#endif
#define ae2f_AnnUtilIdx2(idx1, sz1, idx0, sz0)
Definition Util.h:23
#define ae2f_AnnUtilIdx3(idx2, sz2, idx1, sz1, idx0, sz0)
Definition Util.h:24
#define ae2f_structdef(key, name)
Definition Cast.h:110
#define ae2f_static_cast(t, v)
Definition Cast.h:42
#define ae2f_reg
Register keyword.
Definition Reg.h:12
#define ae2f_LP(...)
Definition Guide.h:23
#define ae2f_opt
Definition Guide.h:26
#define __ae2f_MACRO_GENERATED
Definition Conv.auto.h:2
#define ae2f_MAC_BUILD
Definition Util.h:5
#define ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0)
Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].
Definition Mhattn.h:67
#define ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0)
Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].
Definition Mhattn.h:46
#define ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k)
Definition Mhattn.h:95
#define ae2f_AnnMhattnKDist(prm_mhattn)
m_headc * kdist == m_mdldist
Definition Mhattn.h:32
#define ae2f_NEED_CLASS
Definition Mlp.cl.c:8
#define ae2f_MAC(...)
Definition mac.h:28