文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
注:本文为李沐大神的《动手学深度学习》的课程笔记!
创建数据集
1 | # 导入mxnet |
数据展示
1 | %matplotlib inline |
数据读取
1 | # 训练时的批数据大小 |
1 | # 查看数据 |
[[-2.11255503 0.61242002]
[ 2.18546367 -0.48856559]
[ 0.91085583 0.38985687]
[-0.56097323 1.44421673]
[ 0.31765923 -1.75729597]
[-0.57738042 2.03963804]
[-0.91808975 0.64181799]
[-0.20269176 0.21012937]
[-0.22549874 0.19895147]
[ 1.42844415 0.06982213]]
<NDArray 10x2 @cpu(0)>
[ -2.11691356 10.22533131 4.70613146 -1.82755637 10.82125568
-3.88111711 0.17608714 3.07074499 3.06542921 6.82972908]
<NDArray 10 @cpu(0)>
定义模型
1 | # 定义一个空的模型 |
初始化模型参数
1 | net.initialize() |
定义损失函数
1 | square_loss = gluon.loss.L2Loss() |
优化
1 | trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01}) |
训练
1 | # 定义训练的迭代周期 |
Epoch 0, average loss: 7.403182
Epoch 1, average loss: 0.854247
Epoch 2, average loss: 0.099864
Epoch 3, average loss: 0.011887
Epoch 4, average loss: 0.001479
代码地址
https://github.com/SnailTyan/gluon-practice-code
参考资料
ArrayDataset
https://mxnet.incubator.apache.org/api/python/gluon/data.html#mxnet.gluon.data.ArrayDatasetTrainer
https://mxnet.incubator.apache.org/api/python/gluon/gluon.html?highlight=trainer#mxnet.gluon.Trainer