ae2f_docs
MlpTrainXOR.cc
Go to the documentation of this file.
1#include <ae2f/Ann/Mlp.h>
2#include <cstdlib>
3#include <cstdio>
4#include <math.h>
5
6static void Act(ae2f_float_t* r, ae2f_float_t x) {
7 *r = 1.0 / (1.0 + exp(-x));
8}
9
10static void ActDeriv(ae2f_float_t* r, ae2f_float_t output) {
11 *r = output * (1.0 - output);
12}
13
14static void LossDerivCROSS(ae2f_float_t* r, const ae2f_float_t* output, const ae2f_float_t* target, size_t i, size_t c) {
15 const ae2f_float_t epsilon = 1e-7;
16 ae2f_float_t o_i = output[i];
17 o_i = o_i < epsilon ? epsilon : (o_i > 1.0 - epsilon ? 1.0 - epsilon : o_i);
18 r[0] = (o_i - target[i]) / (c * o_i * (1.0 - o_i));
19}
20
21const ae2f_float_t inp[4][2] = {
22 {0, 0},
23 {0, 1},
24 {1, 0},
25 {1, 1}
26};
27
28const ae2f_float_t goal_xor[4] = {0, 1, 1, 0}; // Fixed -0 to 0
29
31ae2f_err_t err[1] = {0};
32ae2f_AnnMlp* mlp;
33size_t lenv[] = {2, 5, 5, 1}; // Simplified architecture: 2 input, 4 hidden, 1 output
34
35int main() {
36 // Use cross-entropy loss and higher learning rate
37 ae2f_AnnMlpMk(err, &mlp, 4, lenv, 0, 0, 0, LossDerivCROSS, 0, 0, 0, 0, 0.5, 0.5, 13, 9);
38
39 if (err[0]) {
40 printf("[Error]: %d\n", err[0]);
41 return 1;
42 }
43
44 // Initialize weights and biases with larger range, all layers with sigmoid
45 for (size_t i = 0; i < mlp->m_depth - 1; i++) {
46 mlp->m_act[i] = Act;
47 mlp->m_actderiv[i] = ActDeriv;
48 size_t outc = lenv[i + 1];
49 size_t inc = lenv[i];
50 // Initialize biases
51 for (size_t j = 0; j < outc; j++) {
52 mlp->m_bias[i * mlp->m_outc + j] = ((ae2f_float_t)rand() / RAND_MAX - 0.5) * 2.0; // [-1, 1]
53 }
54 // Initialize weights
55 for (size_t j = 0; j < inc * outc; j++) {
56 mlp->m_weight[i * mlp->m_outc * mlp->m_outc + j] = ((ae2f_float_t)rand() / RAND_MAX - 0.5) * 2.0; // [-1, 1]
57 }
58 }
59
60 // Initial predictions
61 for (size_t j = 0; j < 4; j++) {
62 mlp->Predict(err, inp[j], output);
63 if (err[0]) {
64 printf("[Error in Predict]: %d\n", err[0]);
66 return 1;
67 }
68 printf("Before train: %f %f -> %f (target: %f)\n", inp[j][0], inp[j][1], output[0], goal_xor[j]);
69 }
70
71 // Training loop with mini-batch (all 4 inputs per epoch)
72 for (size_t i = 0; i < 10000; i++) {
73 for (size_t j = 0; j < 4; j++) {
74 mlp->TrainAuto(err, inp[j], goal_xor + j);
75 if (err[0]) {
76 printf("[Error in TrainAutoStream]: %d\n", err[0]);
78 return 1 ;
79 }
80 }
81 if (i % 1000 == 0) {
82 printf("\nEpoch %zu:\n", i);
83 for (size_t j = 0; j < 4; j++) {
84 output[0] = 0;
85 mlp->Predict(err, inp[j], output);
86 if (err[0]) {
87 printf("[Error in PredictStream]: %d\n", err[0]);
89 return 1;
90 }
91 printf("\t%f %f -> %f (target: %f)\n", inp[j][0], inp[j][1], output[0], goal_xor[j]);
92 }
93 }
94 }
95
96 // Final predictions
97 printf("\nFinal results:\n");
98 for (size_t j = 0; j < 4; j++) {
99 output[0] = 0;
100 mlp->Predict(err, inp[j], output);
101 if (err[0]) {
102 printf("[Error in Predict]: %d\n", err[0]);
104 return 1;
105 }
106 printf("%zu, %f %f -> %f (target: %f)\n", j, inp[j][0], inp[j][1], output[0], goal_xor[j]);
107 }
108
110 return 0;
111}
Referenced symbols:
- ae2f_float ae2f_float_t — Definition: Float.h:38
- ae2f_float_t output[1]
- ae2f_AnnMlp_t mlp
- ae2f_err_t err[1] — Definition: MlpTrainXOR.c:40
- size_t lenv[] — Definition: MlpTrainXOR.c:42
- const ae2f_float_t goal_xor[4]
- const ae2f_float_t inp[4][2]
- int main()
- ae2f_SHAREDEXPORT void ae2f_AnnMlpDel(ae2f_AnnMlp *restrict const block) noexcept — Definition: Mlp.imp.c:40
- ae2f_SHAREDEXPORT void ae2f_AnnMlpMk(ae2f_err_t *restrict const reterr, ae2f_AnnMlp *restrict *restrict const retmk, const size_t depth, const size_t *restrict const szvector, ae2f_opt size_t *restrict const szswap_opt, ae2f_opt ae2f_AnnAct_t **restrict const act, ae2f_opt ae2f_AnnAct_t **restrict const actderiv, ae2f_AnnLoss_t *const lossderiv, ae2f_opt ae2f_float_t *restrict const deltastream, ae2f_opt ae2f_float_t *restrict const outcache, ae2f_opt ae2f_float_t *restrict const weight, ae2f_opt ae2f_float_t *restrict const bias, ae2f_float_t const learningrate, ae2f_float_t const learningrate_bias, const size_t offset, const size_t extra) noexcept — Definition: Mlp.imp.c:5
- uint8_t ae2f_err_t — Informs that this number represents the error. Definition: errGlob.h:19