ae2f_docs
Mhattn.h File Reference

Multi-head attention.

#include <ae2f/Cast.h>
#include <ae2f/Guide.h>
#include "./Act.h"
#include "./Mhattn.auto.h"
#include "./Mhattn.core.h"
#include "./Util.h"
#include <ae2f/Pack/Beg.h>

Macros

#define ae2f_Ann_Mhattn_h
#define ae2f_AnnMhattnKDist(prm_mhattn)
 Per-head width kdist, defined so that m_headc * kdist == m_mdldist.
#define ae2f_AnnMhattnHeadSplit_imp(prm_mhattn, prm_seqlen, m_i2, m_i1, m_i0)
 Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].
#define ae2f_AnnMhattnHeadConcat_imp(prm_mhattn, prm_seqlen, m_i1, m_i0)
 Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].
#define ae2f_AnnMhattnFwdSeqConvOne_imp( prm_seq, prm_w, prm_mdldist, prm_seqlen, prm_i, prm_j, prm_k)

Functions

 ae2f_structdef (struct, ae2f_AnnMhattn_t)

Detailed Description

Multi-head attention: layer metadata and index-redirection macros.

Definition in file Mhattn.h.

Macro Definition Documentation

◆ ae2f_Ann_Mhattn_h

#define ae2f_Ann_Mhattn_h

Definition at line 11 of file Mhattn.h.

◆ ae2f_AnnMhattnFwdSeqConvOne_imp

#define ae2f_AnnMhattnFwdSeqConvOne_imp ( prm_seq,
prm_w,
prm_mdldist,
prm_seqlen,
prm_i,
prm_j,
prm_k )
Value:
((prm_seq)[ae2f_AnnUtilIdx2(prm_i, prm_seqlen, prm_j, prm_mdldist)] * \
(prm_w)[ae2f_AnnUtilIdx2(prm_j, prm_mdldist, prm_k, prm_mdldist)])

One term of the sequence-by-weight product: prm_seq[prm_i, prm_j] * prm_w[prm_j, prm_k].

Original implementation by @kenter7317. Macrofied by @dalmurii.

Parameters
prm_seq      {const ae2f_float_t* const}  shape (prm_seqlen, prm_mdldist)
prm_w        {const ae2f_float_t* const}  shape (prm_mdldist, prm_mdldist)
prm_mdldist  {const size_t}
prm_seqlen   {const size_t}
prm_i        {const size_t}  < prm_seqlen
prm_j        {const size_t}  < prm_mdldist
prm_k        {const size_t}  < prm_mdldist

Definition at line 95 of file Mhattn.h.
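Judging from its Value, the macro yields a single product term, so summing it over prm_j would give one element of the matrix product of the sequence and the weight. The sketch below illustrates that reading. It is not the library's code: `demo_Idx2` is an assumed row-major stand-in for `ae2f_AnnUtilIdx2` (whose definition in Util.h is not shown on this page), `float` stands in for `ae2f_float_t`, and `demo_project_one` is a hypothetical helper.

```c
#include <assert.h>
#include <stddef.h>

/* ASSUMPTION: row-major flattening of (i, j) in an [isz, jsz] array.
 * The real ae2f_AnnUtilIdx2 is defined in Util.h and may differ. */
#define demo_Idx2(i, isz, j, jsz) ((i) * (jsz) + (j))

/* Same shape as the documented Value, under a demo name. */
#define demo_FwdSeqConvOne(seq, w, mdldist, seqlen, i, j, k) \
    ((seq)[demo_Idx2(i, seqlen, j, mdldist)] *               \
     (w)[demo_Idx2(j, mdldist, k, mdldist)])

/* Hypothetical helper: accumulates the product terms over j to obtain
 * element (i, k) of seq x w. */
static float demo_project_one(const float *seq, const float *w,
                              size_t mdldist, size_t seqlen,
                              size_t i, size_t k) {
    float acc = 0.0f;
    size_t j;
    assert(i < seqlen && k < mdldist); /* documented index bounds */
    for (j = 0; j < mdldist; ++j)
        acc += demo_FwdSeqConvOne(seq, w, mdldist, seqlen, i, j, k);
    return acc;
}
```

Keeping the macro to a single term leaves the accumulation order and accumulator type to the caller.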

◆ ae2f_AnnMhattnHeadConcat_imp

#define ae2f_AnnMhattnHeadConcat_imp ( prm_mhattn,
prm_seqlen,
m_i1,
m_i0 )
Value:
ae2f_AnnUtilIdx3( \
(m_i0) / ae2f_AnnMhattnKDist(prm_mhattn) \
, (prm_mhattn).m_headc \
, m_i1, prm_seqlen \
, (m_i0) % ae2f_AnnMhattnKDist(prm_mhattn) \
, ae2f_AnnMhattnKDist(prm_mhattn) \
)

Index redirector from [prm_seqlen, m_mdldist] to [m_headc, prm_seqlen, kdist].

Returns
{const size_t}
Parameters
prm_mhattn  {const ae2f_AnnMhattn_t&}
prm_seqlen  {const size_t}
m_i1        {const size_t}  < prm_seqlen
m_i0        {const size_t}  < m_mdldist

Definition at line 67 of file Mhattn.h.
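As an illustration of the redirection, the sketch below fills a [prm_seqlen, m_mdldist] buffer by reading from a head-major [m_headc, prm_seqlen, kdist] buffer through a copy of the macro. The names are hypothetical: `demo_mhattn_t` carries only the two fields the macro reads, and `demo_Idx3` is an assumed row-major stand-in for `ae2f_AnnUtilIdx3`, whose real definition is not shown on this page.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for ae2f_AnnMhattn_t (only the fields used here). */
typedef struct {
    size_t m_mdldist; /* model width */
    size_t m_headc;   /* number of heads */
} demo_mhattn_t;

/* ASSUMPTION: row-major flattening of (i, j, k) in an [isz, jsz, ksz] array. */
#define demo_Idx3(i, isz, j, jsz, k, ksz) \
    (((i) * (jsz) + (j)) * (ksz) + (k))

#define demo_KDist(a) ((a).m_mdldist / (a).m_headc)

/* Same shape as the documented Value, under a demo name. */
#define demo_HeadConcat(a, seqlen, i1, i0)                    \
    demo_Idx3((i0) / demo_KDist(a), (a).m_headc, i1, seqlen,  \
              (i0) % demo_KDist(a), demo_KDist(a))

/* Hypothetical helper: gathers per-head rows into one
 * [seqlen, m_mdldist] buffer. */
static void demo_concat_heads(demo_mhattn_t a, size_t seqlen,
                              const float *heads, float *out) {
    size_t t, c;
    assert(a.m_headc != 0 && a.m_mdldist % a.m_headc == 0);
    for (t = 0; t < seqlen; ++t)
        for (c = 0; c < a.m_mdldist; ++c)
            out[t * a.m_mdldist + c] =
                heads[demo_HeadConcat(a, seqlen, t, c)];
}
```

Under these assumptions, column m_i0 of the concatenated layout comes from head m_i0 / kdist at offset m_i0 % kdist.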

◆ ae2f_AnnMhattnHeadSplit_imp

#define ae2f_AnnMhattnHeadSplit_imp ( prm_mhattn,
prm_seqlen,
m_i2,
m_i1,
m_i0 )
Value:
ae2f_AnnUtilIdx2( \
m_i1, prm_seqlen \
, ((m_i0) + (m_i2) * ae2f_AnnMhattnKDist(prm_mhattn)) \
, (prm_mhattn).m_mdldist)

Index redirector from [m_headc, prm_seqlen, kdist] to [prm_seqlen, m_mdldist].

Returns
{const size_t}
Parameters
prm_mhattn  {const ae2f_AnnMhattn_t&}
prm_seqlen  {const size_t}
m_i2        {const size_t}  < m_headc
m_i1        {const size_t}  < prm_seqlen
m_i0        {const size_t}  < kdist

Definition at line 46 of file Mhattn.h.
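To make the mapping concrete, the sketch below copies one head's slice out of a [prm_seqlen, m_mdldist] buffer by indexing it with a copy of the macro. It rests on assumptions: `demo_Idx2` is a row-major stand-in for `ae2f_AnnUtilIdx2` (not shown on this page), `demo_mhattn_t` is a hypothetical struct with only the fields the macro reads, and `demo_gather_head` is an illustrative helper, not part of the library.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for ae2f_AnnMhattn_t (only the fields used here). */
typedef struct {
    size_t m_mdldist; /* model width */
    size_t m_headc;   /* number of heads */
} demo_mhattn_t;

/* ASSUMPTION: row-major flattening of (i, j) in an [isz, jsz] array. */
#define demo_Idx2(i, isz, j, jsz) ((i) * (jsz) + (j))

#define demo_KDist(a) ((a).m_mdldist / (a).m_headc)

/* Same shape as the documented Value, under a demo name. */
#define demo_HeadSplit(a, seqlen, i2, i1, i0)         \
    demo_Idx2(i1, seqlen,                             \
              ((i0) + (i2) * demo_KDist(a)),          \
              (a).m_mdldist)

/* Hypothetical helper: copies head `head` from a [seqlen, m_mdldist]
 * buffer into a dense [seqlen, kdist] buffer. */
static void demo_gather_head(demo_mhattn_t a, size_t seqlen, size_t head,
                             const float *src, float *dst) {
    size_t kdist = demo_KDist(a);
    size_t t, d;
    assert(head < a.m_headc); /* documented bound on m_i2 */
    for (t = 0; t < seqlen; ++t)
        for (d = 0; d < kdist; ++d)
            dst[t * kdist + d] = src[demo_HeadSplit(a, seqlen, head, t, d)];
}
```

Under these assumptions, head m_i2 occupies columns m_i2 * kdist through (m_i2 + 1) * kdist - 1 of every row, so no data has to be moved to view the buffer per head.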

◆ ae2f_AnnMhattnKDist

#define ae2f_AnnMhattnKDist ( prm_mhattn)
Value:
((prm_mhattn).m_mdldist / (prm_mhattn).m_headc)

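Per-head width kdist, computed as m_mdldist / m_headc so that m_headc * kdist == m_mdldist. With integer operands the division truncates, so the identity holds only when m_headc divides m_mdldist evenly.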

Parameters
prm_mhattn  {const ae2f_AnnMhattn_t}

Definition at line 32 of file Mhattn.h.
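A minimal sketch of the macro in use, assuming a stand-in struct that carries only the two fields the Value reads (the real ae2f_AnnMhattn_t may hold more). `demo_mhattn_t`, `demo_KDist`, and `demo_kdist_checked` are hypothetical names introduced for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for ae2f_AnnMhattn_t (only the fields used here). */
typedef struct {
    size_t m_mdldist; /* model width */
    size_t m_headc;   /* number of heads */
} demo_mhattn_t;

/* Same shape as the documented Value, under a demo name. */
#define demo_KDist(prm_mhattn) \
    ((prm_mhattn).m_mdldist / (prm_mhattn).m_headc)

/* Hypothetical helper: returns the per-head width and checks the
 * invariant m_headc * kdist == m_mdldist. */
static size_t demo_kdist_checked(demo_mhattn_t a) {
    size_t kdist;
    assert(a.m_headc != 0); /* avoid division by zero */
    kdist = demo_KDist(a);
    assert(a.m_headc * kdist == a.m_mdldist);
    return kdist;
}
```

The second assertion fails for a configuration such as m_mdldist = 10 with m_headc = 4, which is the case the invariant rules out.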

Function Documentation

◆ ae2f_structdef()

ae2f_structdef ( struct ,
ae2f_AnnMhattn_t  )

Metadata for multi-head attention.

Definition at line 25 of file Mhattn.h.