1#include <ae2f/Ann/Mlp.h>
6static void Act(ae2f_float_t* r,
const ae2f_float_t* x, size_t i, size_t) {
7 *r = 1.0 / (1.0 + exp(-x[i]));
10static void ActDeriv(ae2f_float_t* r,
const ae2f_float_t* output, size_t i, size_t) {
11 *r = output[i] * (1.0 - output[i]);
14static void LossDerivCROSS(ae2f_float_t* r,
const ae2f_float_t* output,
const ae2f_float_t* target, size_t i, size_t c) {
15 const ae2f_float_t epsilon = 1e-7;
16 ae2f_float_t o_i = output[i];
17 o_i = o_i < epsilon ? epsilon : (o_i > 1.0 - epsilon ? 1.0 - epsilon : o_i);
18 r[0] = (o_i - target[i]) / (c * o_i * (1.0 - o_i));
21const ae2f_float_t inp[4][2] = {
28const ae2f_float_t goal_xor[4] = {0, 1, 1, 0};
30ae2f_float_t output[2] = {0};
33size_t lenv[] = {2, 5, 5, 1};
37 ae2f_AnnMlpMk(err, &mlp, 4, lenv, 0, 0, 0, LossDerivCROSS, 0, 0, 0, 0, 0.5, 0.5, 13, 9);
40 printf(
"[Error]: %d\n", err[0]);
45 for (size_t i = 0; i < mlp->m_depth - 1; i++) {
47 mlp->m_actderiv[i] = ActDeriv;
48 size_t outc = lenv[i + 1];
51 for (size_t j = 0; j < outc; j++) {
52 mlp->m_bias[i * mlp->m_outc + j] = ((ae2f_float_t)rand() / RAND_MAX - 0.5) * 2.0;
55 for (size_t j = 0; j < inc * outc; j++) {
56 mlp->m_weight[i * mlp->m_outc * mlp->m_outc + j] = ((ae2f_float_t)rand() / RAND_MAX - 0.5) * 2.0;
61 for (size_t j = 0; j < 4; j++) {
62 mlp->Predict(err, inp[j], output);
64 printf(
"[Error in Predict]: %d\n", err[0]);
68 printf(
"Before train: %f %f -> %f (target: %f)\n", inp[j][0], inp[j][1], output[0], goal_xor[j]);
72 for (size_t i = 0; i < 10000; i++) {
73 for (size_t j = 0; j < 4; j++) {
74 mlp->TrainAuto(err, inp[j], goal_xor + j);
76 printf(
"[Error in TrainAutoStream]: %d\n", err[0]);
82 printf(
"\nEpoch %zu:\n", i);
83 for (size_t j = 0; j < 4; j++) {
85 mlp->Predict(err, inp[j], output);
87 printf(
"[Error in PredictStream]: %d\n", err[0]);
91 printf(
"\t%f %f -> %f (target: %f)\n", inp[j][0], inp[j][1], output[0], goal_xor[j]);
97 printf(
"\nFinal results:\n");
98 for (size_t j = 0; j < 4; j++) {
100 mlp->Predict(err, inp[j], output);
102 printf(
"[Error in Predict]: %d\n", err[0]);
106 printf(
"%zu, %f %f -> %f (target: %f)\n", j, inp[j][0], inp[j][1], output[0], goal_xor[j]);
uint8_t ae2f_err_t
Indicates that a value of this type represents an error code.