Featured image of post 两百行代码搓一个原始神经网络

两百行代码搓一个原始神经网络

这个文章主要还是因为这周研究了挺久的AI相关的东西,然后休闲时刷视频发现了这个

We’re Building Computers Wrong

闲的没事尝试复现了一下其中的第一个demo,这也是是一个很简单但很有意思的demo 如果你闲的蛋疼可以考虑看一下视频,就能大致知到原理了,当然你也可以考虑跳过,我在下面简述一下

另外完整的代码我已经放在 github 了,点一下我头像下面的连接就可以找到

基本原理

如果你看了视频,这一步可以跳过

1958年,Frank Rosenblatt提出了一种模仿神经发射信号方式的机器 Mark I 感知机,Mark I 的输入层是一个 20x20 的感光单元矩阵,共400个光电探测器,用于"看"图像,并将光学信号转化为电信号,神经元之间连接的"权重"被编码在电位器中 回到神经细胞(这里不是严格生物意义上的神经细胞)上,一个神经细胞可以有两种激活状态,激活或者不激活,而神经细胞之间可以有连接,如果一个神经细胞连接数个神经细胞,这个细胞的激活状态由其他神经细胞决定,这些神经细胞的激活与不激活的状态通过连接传递给这个神经细胞,不同的神经细胞之间的连接强度不同,部分连接可以起到刺激作用,部分为抑制作用,当整合后的信号超过某个阈值,这个神经细胞就会激活

我们可以把这里的连接强度视作一种权重,不同神经细胞的输入乘以对应的权重的和是否能够达到某个阈值决定下一个神经细胞的状态,因此我们可以很简单的模拟出神经细胞的运作方式,复现一下 Mark I 感知机 有可能你看的一头雾水,直接上实战演示

搓一个原始神经网络

这里的原理并不是特别复杂,因此你甚至可以使用c语言搓一个出来,所以我们就直接掏出巨大的c语言,直接做一个简单的demo使得可以分辨方形和圆形

基本处理

为了方便接下来的代码编写,我们先搓一个common.h文件

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
#ifndef COMMON_H
#define COMMON_H

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define WIDTH  20         // 图形宽度
#define HEIGHT 20         // 图形高度
#define BIAS   20.0f      // 分类阈值

神经元输入

首先我们需要一个 20*20 的"感光单元矩阵",当然我们使用一个 20*20 char数组就行了,数组的每个索引处可以是一个数,这里我们为了方便,只采用0和1,这20*20的每个神经元的权重我们同样采用一个数组保存,但是这个没办法缩减,使用一个float数组

1
2
typedef unsigned char Image[HEIGHT][WIDTH];
typedef float Weights[HEIGHT][WIDTH];

图像的保存与加载

这里为了简单方便的表示,我们就不使用.png .jpg之类的复杂的图片,我们使用一个.bin文件,每个字节代表一个像素,每个图片20*20=400个字节

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
static inline int load_image(const char *filepath, Image img) {
    FILE *f = fopen(filepath, "rb");
    if (!f) return -1;
    size_t n = fread(img, sizeof(unsigned char), HEIGHT * WIDTH, f);
    fclose(f);
    return (n == HEIGHT * WIDTH) ? 0 : -1;
}

static inline void save_weights(const Weights w, const char *path) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        fprintf(stderr, "ERROR: cannot write weights to %s\n", path);
        exit(1);
    }
    fwrite(w, sizeof(float), HEIGHT * WIDTH, f);
    fclose(f);
}

连接权重的加载

同样,我们可以给将权重保存到.bin文件,并从此加载

1
2
3
4
5
6
7
8
9
static inline void load_weights(Weights w, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) {
        fprintf(stderr, "ERROR: cannot load weights from %s\n", path);
        exit(1);
    }
    fread(w, sizeof(float), HEIGHT * WIDTH, f);
    fclose(f);
}

加权和计算

我们可以直接将每个索引与其对应的权重相乘并计算其和

1
2
3
4
5
6
7
static inline float feed_forward(const Image img, const Weights w) {
    float sum = 0.0f;
    for (int y = 0; y < HEIGHT; ++y)
        for (int x = 0; x < WIDTH; ++x)
            sum += img[y][x] * w[y][x];
    return sum;
}

common.h

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#ifndef COMMON_H
#define COMMON_H

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define WIDTH  20
#define HEIGHT 20
#define BIAS   20.0f      // 分类阈值

// 图像数据类型(每个像素 0 或 1)
typedef unsigned char Image[HEIGHT][WIDTH];

// 权重数据类型(浮点数)
typedef float Weights[HEIGHT][WIDTH];

// 读取 .bin 图像文件
static inline int load_image(const char *filepath, Image img) {
    FILE *f = fopen(filepath, "rb");
    if (!f) return -1;
    size_t n = fread(img, sizeof(unsigned char), HEIGHT * WIDTH, f);
    fclose(f);
    return (n == HEIGHT * WIDTH) ? 0 : -1;
}

// 保存权重到二进制文件
static inline void save_weights(const Weights w, const char *path) {
    FILE *f = fopen(path, "wb");
    if (!f) {
        fprintf(stderr, "ERROR: cannot write weights to %s\n", path);
        exit(1);
    }
    fwrite(w, sizeof(float), HEIGHT * WIDTH, f);
    fclose(f);
}

// 加载权重文件
static inline void load_weights(Weights w, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) {
        fprintf(stderr, "ERROR: cannot load weights from %s\n", path);
        exit(1);
    }
    fread(w, sizeof(float), HEIGHT * WIDTH, f);
    fclose(f);
}

// 前向传播:计算加权和
static inline float feed_forward(const Image img, const Weights w) {
    float sum = 0.0f;
    for (int y = 0; y < HEIGHT; ++y)
        for (int x = 0; x < WIDTH; ++x)
            sum += img[y][x] * w[y][x];
    return sum;
}

// 更新规则:权重 += 输入(用于圆形误判为矩形时)
static inline void add_input(Weights w, const Image img) {
    for (int y = 0; y < HEIGHT; ++y)
        for (int x = 0; x < WIDTH; ++x)
            w[y][x] += img[y][x];
}

// 更新规则:权重 -= 输入(用于矩形误判为圆形时)
static inline void sub_input(Weights w, const Image img) {
    for (int y = 0; y < HEIGHT; ++y)
        for (int x = 0; x < WIDTH; ++x)
            w[y][x] -= img[y][x];
}

#endif // COMMON_H

预测模型

将上一步准备的几个函数和权重文件以及图片文件捏合起来就可以做到预测了 我们设定一个BIAS阈值,如果大于则表示图形是一个矩形,小于则是圆,所以可以得到以下代码

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int main(int argc, char *argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <image.bin>\n", argv[0]);
        return 1;
    }

    Weights weights;
    load_weights(weights, "weights.bin");

    Image img;
    if (load_image(argv[1], img) != 0) {
        fprintf(stderr, "ERROR: cannot load image %s\n", argv[1]);
        return 1;
    }

    float out = feed_forward(img, weights);
    printf("Weighted sum = %f\n", out);
    if (out > BIAS)
        printf("Prediction: RECTANGLE\n");
    else if (out < BIAS)
        printf("Prediction: CIRCLE\n");
    else
        printf("Prediction: UNCERTAIN (exactly at threshold)\n");

    return 0;
}

训练模型

光有预测用的工具,没有对应的权重文件怎么办,我们还需要自己手动训练

训练素材

如果手搓数百个素材实在是太不健康了,此时就应该自己写一个批量生成素材的小工具

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <sys/stat.h>

#define WIDTH       20
#define HEIGHT      20
#define SAMPLE_SIZE 500          // 每类样本数量(矩形和圆形各生成这么多)
#define DATA_DIR    "./train_data"
#define RANDOM_SEED 42          // 固定随机种子,使样本可复现

static inline int rand_range(int low, int high) {
    return rand() % (high - low) + low;
}

static void ensure_dir(const char *path) {
    struct stat st;
    if (stat(path, &st) == -1) {
        if (mkdir(path, 0755) == -1) {
            fprintf(stderr, "ERROR: cannot create directory %s: %s\n",
                    path, strerror(errno));
            exit(1);
        }
    }
}

void generate_random_rect(unsigned char img[HEIGHT][WIDTH]) {
    memset(img, 0, HEIGHT * WIDTH * sizeof(unsigned char));
    int x = rand_range(0, WIDTH);
    int y = rand_range(0, HEIGHT);
    int w = rand_range(1, WIDTH - x + 1);
    int h = rand_range(1, HEIGHT - y + 1);

    for (int dy = 0; dy < h; ++dy) {
        for (int dx = 0; dx < w; ++dx) {
            img[y + dy][x + dx] = 1;
        }
    }
}

void generate_random_circle(unsigned char img[HEIGHT][WIDTH]) {
    memset(img, 0, HEIGHT * WIDTH * sizeof(unsigned char));
    int cx = rand_range(0, WIDTH);
    int cy = rand_range(0, HEIGHT);
    int max_r = cx;
    if (cy < max_r) max_r = cy;
    if (WIDTH - 1 - cx < max_r) max_r = WIDTH - 1 - cx;
    if (HEIGHT - 1 - cy < max_r) max_r = HEIGHT - 1 - cy;
    if (max_r < 1) max_r = 1;

    int r = rand_range(1, max_r + 1);
    for (int y = 0; y < HEIGHT; ++y) {
        for (int x = 0; x < WIDTH; ++x) {
            int dx = x - cx;
            int dy = y - cy;
            if (dx * dx + dy * dy <= r * r) {
                img[y][x] = 1;
            }
        }
    }
}

void save_image(const unsigned char img[HEIGHT][WIDTH], const char *filename) {
    FILE *f = fopen(filename, "wb");
    if (!f) {
        fprintf(stderr, "ERROR: cannot open %s for writing: %s\n",
                filename, strerror(errno));
        exit(1);
    }
    // 直接写入 HEIGHT*WIDTH 个字节,每个字节是0或1
    fwrite(img, sizeof(unsigned char), HEIGHT * WIDTH, f);
    fclose(f);
}

void generate_all_samples(void) {
    unsigned char img[HEIGHT][WIDTH];
    for (int i = 0; i < SAMPLE_SIZE; ++i) {
        generate_random_rect(img);
        char filename[256];
        snprintf(filename, sizeof(filename), "%s/rect_%03d.bin", DATA_DIR, i);
        save_image(img, filename);
        printf("[INFO] generated %s\n", filename);
    }
    for (int i = 0; i < SAMPLE_SIZE; ++i) {
        generate_random_circle(img);
        char filename[256];
        snprintf(filename, sizeof(filename), "%s/circle_%03d.bin", DATA_DIR, i);
        save_image(img, filename);
        printf("[INFO] generated %s\n", filename);
    }
}

int main(void) {
    srand(RANDOM_SEED);
    ensure_dir(DATA_DIR);
    generate_all_samples();

    printf("[INFO] Done. Generated %d rectangles and %d circles in %s\n",
           SAMPLE_SIZE, SAMPLE_SIZE, DATA_DIR);
    return 0;
}

按照这里,批量搓了两种素材各500分,当然你要是想检查一下素材也可以再整一个检查的工具

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>

#define WIDTH  20
#define HEIGHT 20
#define DATA_DIR "./train_data"

// 打印图像:'#' 表示1(形状部分),'.' 表示0(背景)
static void print_image(const unsigned char img[HEIGHT][WIDTH]) {
    printf("+");
    for (int x = 0; x < WIDTH; ++x)
        printf("-");
    printf("+\n");

    for (int y = 0; y < HEIGHT; ++y) {
        printf("|");
        for (int x = 0; x < WIDTH; ++x) {
            putchar(img[y][x] ? '#' : '.');
        }
        printf("|\n");
    }

    printf("+");
    for (int x = 0; x < WIDTH; ++x)
        printf("-");
    printf("+\n");
}

int main(int argc, char *argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s <filename>\n", argv[0]);
        fprintf(stderr, "Example: %s rect_000.bin\n", argv[0]);
        return 1;
    }

    const char *filename = argv[1];
    char fullpath[512];

    if (strncmp(filename, DATA_DIR, strlen(DATA_DIR)) == 0) {
        snprintf(fullpath, sizeof(fullpath), "%s", filename);
    } else {
        snprintf(fullpath, sizeof(fullpath), "%s/%s", DATA_DIR, filename);
    }

    FILE *f = fopen(fullpath, "rb");
    if (!f) {
        fprintf(stderr, "Error: cannot open %s: %s\n", fullpath, strerror(errno));
        return 1;
    }

    unsigned char img[HEIGHT][WIDTH];
    size_t bytes_read = fread(img, sizeof(unsigned char), HEIGHT * WIDTH, f);
    fclose(f);

    if (bytes_read != HEIGHT * WIDTH) {
        fprintf(stderr, "Error: file size mismatch (expected %d bytes, got %zu)\n",
                HEIGHT * WIDTH, bytes_read);
        return 1;
    }

    printf("File: %s\n", fullpath);
    print_image(img);

    return 0;
}

训练权重

我们这里采用错误驱动训练,这里粘贴一段解释 alt text

矩形和圆形在像素空间并不是线性可分的(因为形状复杂),但在这个二值化、固定尺寸的表示下,它们实际上是线性可分的

感知机收敛定理指出:如果训练数据线性可分,那么经过有限次错误更新,算法会找到一个将两类完全分开的权重向量, 程序在每一轮完整遍历所有样本,当某轮没有错误时停止,因此最终会收敛

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include "common.h"

#define SAMPLE_SIZE   500
#define TRAIN_PASSES  4000
#define DATA_DIR      "./train_data"

int main(void) {
    Image img;
    Weights weights = {0};
    char path[256];

    char rect_files[SAMPLE_SIZE][256];
    char circle_files[SAMPLE_SIZE][256];
    for (int i = 0; i < SAMPLE_SIZE; ++i) {
        snprintf(rect_files[i], sizeof(rect_files[i]), "%s/rect_%03d.bin", DATA_DIR, i);
        snprintf(circle_files[i], sizeof(circle_files[i]), "%s/circle_%03d.bin", DATA_DIR, i);
    }

    printf("Start training...\n");

    for (int pass = 0; pass < TRAIN_PASSES; ++pass) {
        int adjusted = 0;

        // 训练矩形:期望输出 > BIAS
        for (int i = 0; i < SAMPLE_SIZE; ++i) {
            if (load_image(rect_files[i], img) != 0) {
                fprintf(stderr, "ERROR: cannot load %s\n", rect_files[i]);
                return 1;
            }
            float out = feed_forward(img, weights);
            if (out <= BIAS) {          // 误判为圆形 → 需要增大输出
                add_input(weights, img); // 修正:加输入
                adjusted++;
            }
        }

        // 训练圆形:期望输出 < BIAS
        for (int i = 0; i < SAMPLE_SIZE; ++i) {
            if (load_image(circle_files[i], img) != 0) {
                fprintf(stderr, "ERROR: cannot load %s\n", circle_files[i]);
                return 1;
            }
            float out = feed_forward(img, weights);
            if (out >= BIAS) {          // 误判为矩形 → 需要减小输出
                sub_input(weights, img); // 修正:减输入
                adjusted++;
            }
        }

        printf("Pass %3d: adjusted = %d\n", pass, adjusted);
        if (adjusted == 0) {
            printf("Converged at pass %d.\n", pass);
            break;
        }
    }

    save_weights(weights, "weights.bin");
    printf("Training finished. Weights saved to weights.bin\n");

    return 0;
}

测试一下

我们手动准备两个样本

demo1

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
+--------------------+
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|............#######.|
|............#######.|
|............#######.|
|............#######.|
|............#######.|
|............#######.|
|............#######.|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
+--------------------+

测试一下输出

1
2
Weighted sum = 238.000000
Prediction: RECTANGLE

看样子没有问题

demo2

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
+--------------------+
|....................|
|...............#....|
|.............#####..|
|.............#####..|
|............#######.|
|.............#####..|
|.............#####..|
|...............#....|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
|....................|
+--------------------+

康康输出

1
2
Weighted sum = -151.000000
Prediction: CIRCLE

还是比较成功,当然因为训练的原因,并不能所有的素材的非常精确的识别,还需要继续修改训练,修改BIAS和学习率之类的,只要样本和训练得当就能整一个出来

Licensed under CC BY-NC-SA 4.0