A multi-head attention.
More...
#include <ae2f/Cast.h>
#include <ae2f/Guide.h>
#include "./Act.h"
#include "./Mhattn.auto.h"
#include "./Mhattn.core.h"
#include "./Util.h"
#include <ae2f/Pack/Beg.h>
Go to the source code of this file.
|
| #define | ae2f_Ann_Mhattn_h |
| #define | ae2f_AnnMhattnKDist(prm_mhattn) |
| | Per-head width kdist, computed as m_mdldist / m_headc, so that m_headc * kdist == m_mdldist.
|
| #define | ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0) |
| | Index redirector: maps element (m_i2, m_i1, m_i0) of the [m_headc, prm_seqlen, kdist] layout to its index in the [prm_seqlen, m_mdldist] layout.
|
| #define | ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0) |
| | Index redirector: maps element (m_i1, m_i0) of the [prm_seqlen, m_mdldist] layout to its index in the [m_headc, prm_seqlen, kdist] layout.
|
| #define | ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k) |
A multi-head attention.
Definition in file Mhattn.h.
◆ ae2f_Ann_Mhattn_h
| #define ae2f_Ann_Mhattn_h |
◆ ae2f_AnnMhattnFwdSeqConvOne_imp
| #define ae2f_AnnMhattnFwdSeqConvOne_imp |
( |
| prm_seq, |
|
|
| prm_w, |
|
|
| prm_mdldist, |
|
|
| prm_seqlen, |
|
|
| prm_i, |
|
|
| prm_j, |
|
|
| prm_k ) |
Value: ((prm_seq)[ae2f_AnnUtilIdx2(prm_i, prm_seqlen, prm_j, prm_mdldist)] * \
(prm_w)[ae2f_AnnUtilIdx2(prm_j, prm_mdldist, prm_k, prm_mdldist)])
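Yields a single product term: element (prm_i, prm_j) of prm_seq multiplied by element (prm_j, prm_k) of prm_w.
Original implementation by @kenter7317. Converted to a macro by @dalmurii.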
- Parameters
-
| prm_seq | {const ae2f_float_t* const} array indexed as (prm_seqlen, prm_mdldist) |
| prm_w | {const ae2f_float_t* const} array indexed as (prm_mdldist, prm_mdldist) |
| prm_mdldist | {const size_t} |
| prm_seqlen | {const size_t} |
| prm_i | {const size_t} < prm_seqlen |
| prm_j | {const size_t} < prm_mdldist |
| prm_k | {const size_t} < prm_mdldist |
Definition at line 95 of file Mhattn.h.
◆ ae2f_AnnMhattnHeadConcat_imp
| #define ae2f_AnnMhattnHeadConcat_imp |
( |
| prm_mhattn, |
|
|
| prm_seqlen, |
|
|
| m_i1, |
|
|
| m_i0 ) |
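Value: ae2f_AnnUtilIdx3( \
(m_i0) / ae2f_AnnMhattnKDist(prm_mhattn) \
, (prm_mhattn).m_headc \
, m_i1, prm_seqlen \
, (m_i0) % ae2f_AnnMhattnKDist(prm_mhattn), ae2f_AnnMhattnKDist(prm_mhattn) \
)
Note: the first and last argument lines were missing from this rendering and have been restored from the documented [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist] mapping; confirm the exact text against line 67 of Mhattn.h.
Here ae2f_AnnMhattnKDist(prm_mhattn) is the per-head width kdist, with m_headc * kdist == m_mdldist.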
Index redirector: maps element (m_i1, m_i0) of the [prm_seqlen, m_mdldist] layout to its index in the [m_headc, prm_seqlen, kdist] layout.
- Returns
- {const size_t}
- Parameters
-
| prm_mhattn | {const ae2f_AnnMhattn_t&} |
| prm_seqlen | {const size_t} |
| m_i1 | {const size_t} < prm_seqlen |
| m_i0 | {const size_t} < m_mdldist |
Definition at line 67 of file Mhattn.h.
◆ ae2f_AnnMhattnHeadSplit_imp
| #define ae2f_AnnMhattnHeadSplit_imp |
( |
| prm_mhattn, |
|
|
| prm_seqlen, |
|
|
| m_i2, |
|
|
| m_i1, |
|
|
| m_i0 ) |
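Value: ae2f_AnnUtilIdx2( \
m_i1, prm_seqlen \
, (m_i2) * ae2f_AnnMhattnKDist(prm_mhattn) + (m_i0) \
, (prm_mhattn).m_mdldist)
Note: the column-index argument line was missing from this rendering and has been restored from the documented [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist] mapping; confirm the exact text against line 46 of Mhattn.h.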
Index redirector: maps element (m_i2, m_i1, m_i0) of the [m_headc, prm_seqlen, kdist] layout to its index in the [prm_seqlen, m_mdldist] layout.
- Returns
- {const size_t}
- Parameters
-
| prm_mhattn | {const ae2f_AnnMhattn_t&} |
| prm_seqlen | {const size_t} |
| m_i2 | {const size_t} < m_headc |
| m_i1 | {const size_t} < prm_seqlen |
| m_i0 | {const size_t} < kdist |
Definition at line 46 of file Mhattn.h.
◆ ae2f_AnnMhattnKDist
| #define ae2f_AnnMhattnKDist |
( |
| prm_mhattn | ) |
|
Value: ((prm_mhattn).m_mdldist / (prm_mhattn).m_headc)
Per-head width kdist, such that m_headc * kdist == m_mdldist.
- Parameters
-
| prm_mhattn | {const ae2f_AnnMhattn_t} |
Definition at line 32 of file Mhattn.h.
◆ ae2f_structdef()
| ae2f_structdef |
( |
struct | , |
|
|
ae2f_AnnMhattn_t | ) |
Metadata for multi-head attention.
Definition at line 25 of file Mhattn.h.