• 欢迎大家交换友链,可在https://www.stubbornhuang.com/申请友情链接进行友链交换申请!

  • 问题反馈可发送邮件到stubbornhuang@qq.com

  • 本站由于前段时间遭受到大量临时和国外邮箱注册,所以对可注册的邮箱类型进行了限制!

  • 在本站开通年度VIP,无限制下载本站资源和阅读本站文章

  • 工资「喂饱肚子」,副业「养活灵魂」!

  • 感谢大家访问本站,希望本站的内容可以帮助到大家!

  • 如果觉得本站的内容有帮助,可以考虑打赏博主哦!

  • 计算机图形学与计算几何经典必备书单整理,下载链接可参考:https://www.stubbornhuang.com/1256/

  • 本站会放置Google广告用于维持域名以及网站服务器费用。

MindSpore – LeNet5的MindSpore实现

MindSpore 发布于2022-02-23 阅读 4,129次 0次评论 0次点赞 本文共2661个字,阅读需要7分钟。

1 LeNet5MindSpore实现

MindSpore技术白皮书中LeNet5网络的MindSpore版本实现,与Pytorch和Tensorflow的版本相比可以让人更快的熟悉MindSpore的使用方式。

以下代码定义以及训练LeNet神经网络的过程。

# -*- coding: utf-8 -*-

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.network.optim import Momentum
from mindspore.train import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset as de

class LeNet5(nn.Cell):
    """
    Lenet网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # 定义所需要的运算
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

if __name__ == '__main__':
    ds = de.MnistDataset(dataset_dir="./MNIST_Data")
    ds = ds.batch(batch_size=64)
    network = LeNet5()
    loss = SoftmaxCrossEntropyWithLogits()
    optimizer = nn.Momentum(network.trainable_params(),learning_rate=0.1, momentum=0.9)
    model = Model(network, loss, optimizer)
    model.train(epoch=10, train_dataset=ds)

在上述代码的

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.network.optim import Momentum
from mindspore.train import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset as de

的部分,导入了MindSpore的相关库和模块。


class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2)
        self.flatten = P.Flatten()

    def construct(self,x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

以上代码部分定义了LeNet5的网络结果。__init__函数实例化了LeNet所用到的所有算子,construct函数定义了LeNet的计算逻辑。


ds = de.MnistDataset(dataset_dir="./MNIST_Data")
    ds = ds.batch(batch_size=64)

以上代码部分从Mnist数据集中读取数据,并生成一个迭代器ds用作训练的输入。


network = LeNet5()

将LeNet5类实例化为network。


loss = SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(),learning_rate=0.1, momentum=0.9)
model = Model(network, loss, optimizer)

使用SoftmaxCrossEntropyWithLogits计算损失loss,并使用momentum优化参数,最后使用定义的损失函数loss和优化器optimizer创建模型。


model.train(epoch=10, train_dataset=ds)

最后使用epoch控制训练迭代次数,调用模型的训练方法,并在每个eval_step对模型进行评估。

欢迎扫码关注我的微信公众号,及时获取文章更新

微信公众号二维码

本文作者:StubbornHuang

版权声明:本文为站长原创文章,如果转载请注明原文链接!

原文标题:MindSpore – LeNet5的MindSpore实现

原文链接:https://www.stubbornhuang.com/1979/

发布于:2022年02月23日 15:40:47

修改于:2023年06月26日 20:37:00

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。

文章末尾
上一篇
WordPress - 修复WordPress升级5.1之后版本评论回复按钮失效不跳转以及不弹出评论框的问题
WordPress
下一篇
WordPress - 修复Markdown编辑器插件WP-Editor.md在插入php代码块后代码中的$符号无法正常显示的问题
WordPress
当前分类随机文章推荐

发表评论

您必须 [ 登录 ] 才能发表留言!

关注我们的公众号

微信公众号