1 LeNet5 Implementation in MindSpore

This is the MindSpore implementation of the LeNet5 network from the MindSpore technical white paper. Compared with the PyTorch and TensorFlow versions, it offers a quick way to become familiar with how MindSpore is used.

The following code defines the LeNet5 network and trains it.

# -*- coding: utf-8 -*-

import mindspore.nn as nn
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train import Model

class LeNet5(nn.Cell):
    """
    Lenet网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the operators used by the network
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # Build the forward pass using the operators defined above
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

if __name__ == '__main__':
    ds = de.MnistDataset(dataset_dir="./MNIST_Data")
    # Resize the 28x28 MNIST images to the 32x32 input expected by LeNet5,
    # rescale pixels to [0, 1], convert HWC to CHW, and cast labels to int32
    ds = ds.map(operations=[CV.Resize((32, 32)), CV.Rescale(1.0 / 255.0, 0.0), CV.HWC2CHW()],
                input_columns="image")
    ds = ds.map(operations=C.TypeCast(mstype.int32), input_columns="label")
    ds = ds.batch(batch_size=64)
    network = LeNet5()
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(network, loss, optimizer)
    model.train(epoch=10, train_dataset=ds)

In the above code, the section

import mindspore.nn as nn
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train import Model

imports the MindSpore libraries and modules used throughout the script.


class LeNet5(nn.Cell):
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

The code above defines the LeNet5 network structure: the __init__ method instantiates all of the operators LeNet5 uses, and the construct method defines LeNet5's forward computation.
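With valid padding, each 5x5 convolution shrinks the spatial size by 4, and each non-overlapping 2x2 max pooling halves it, which is why fc1 expects 16 * 5 * 5 inputs for a 32x32 image. The arithmetic can be checked with a small standalone sketch (plain Python, no MindSpore required):

```python
def conv_valid(size, kernel=5):
    # 'valid' convolution with stride 1: output = input - kernel + 1
    return size - kernel + 1

def max_pool(size, kernel=2, stride=2):
    # non-overlapping 2x2 pooling halves the spatial size
    return (size - kernel) // stride + 1

size = 32                          # LeNet5 input resolution
size = max_pool(conv_valid(size))  # conv1 + pool: 32 -> 28 -> 14
size = max_pool(conv_valid(size))  # conv2 + pool: 14 -> 10 -> 5
flat = 16 * size * size            # 16 output channels of conv2
print(flat)                        # prints 400 == 16 * 5 * 5
```

Running the same arithmetic on raw 28x28 MNIST images gives 16 * 4 * 4 = 256 flattened features, which would not match fc1's input size; that is why the training script resizes the images to 32x32 first.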


ds = de.MnistDataset(dataset_dir="./MNIST_Data")
ds = ds.batch(batch_size=64)

The code above reads data from the MNIST dataset and batches it into a dataset ds that is used as the training input.
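As a side note on sizing (assuming the standard 60,000-image MNIST training split, which the source does not state): with batch_size=64 and the default drop_remainder=False, the dataset yields ceil(60000 / 64) batches per epoch, with the last batch only partially filled. A quick check in plain Python:

```python
import math

num_samples = 60000   # assumed size of the MNIST training split
batch_size = 64

# drop_remainder=False keeps the final, smaller batch
num_batches = math.ceil(num_samples / batch_size)
last_batch = num_samples - (num_batches - 1) * batch_size

print(num_batches, last_batch)   # 938 batches, the last holds 32 samples
```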


network = LeNet5()

This instantiates the LeNet5 class as network.


loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(network, loss, optimizer)

Here SoftmaxCrossEntropyWithLogits computes the training loss, a Momentum optimizer updates the network parameters, and finally Model ties the network, the loss function loss, and the optimizer optimizer together.
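The momentum update itself is simple: a velocity term accumulates a decayed running sum of gradients, and the parameters move along that velocity. The sketch below illustrates the classic rule (v = momentum * v + grad; w -= lr * v) in plain Python; it shows the general technique, not MindSpore's exact internal formula:

```python
def momentum_step(w, v, grad, lr=0.1, momentum=0.9):
    # classic momentum: velocity accumulates decayed gradients,
    # then the parameter moves along the velocity
    v = momentum * v + grad
    w = w - lr * v
    return w, v

w, v = 1.0, 0.0
# with repeated identical gradients, the step size grows as velocity builds up
w, v = momentum_step(w, v, grad=1.0)   # v = 1.0,  w = 0.9
w, v = momentum_step(w, v, grad=1.0)   # v = 1.9,  w = 0.71
```

This is why momentum accelerates progress along directions where gradients consistently agree, while oscillating gradients partially cancel in the velocity.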


model.train(epoch=10, train_dataset=ds)

Finally, the model's train method is called with epoch=10 to control the number of training epochs over the dataset ds.
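Conceptually, model.train wraps a loop like the following sketch: for each epoch, iterate over the batches, run the forward pass, compute the loss, compute gradients, and apply the optimizer update. The toy one-parameter regression below (plain Python, purely illustrative; none of these names come from MindSpore) makes that structure concrete:

```python
# Toy stand-in for model.train: fit y = w * x to targets generated by y = 2 * x
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]   # (input, label) "batches"
w = 0.0
lr = 0.05

first_loss, last_loss = None, None
for epoch in range(10):           # epoch controls the number of passes
    for x, y in data:             # one "batch" at a time
        pred = w * x              # forward pass
        loss = (pred - y) ** 2    # squared-error loss
        grad = 2 * (pred - y) * x # gradient of the loss w.r.t. w
        w = w - lr * grad         # optimizer update
        if first_loss is None:
            first_loss = loss
        last_loss = loss

print(w)   # converges toward 2.0
```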