文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
本文主要是关于PyTorch的一些用法。
1 | import torch |
1 | # 绘制数据散点图 |
1 | # 定义分类网络 |
1 | # 定义网络 |
Net (
(hidden): Linear (2 -> 10)
(prediction): Linear (10 -> 2)
)
1 | # 定义优化方法 |
1 | # torch.max用法 |
-1.8524 -1.0491 0.5382 -0.5129
0.1233 -0.1821 2.1519 -1.4547
-1.0267 0.2644 -0.8832 -0.2647
0.3944 -1.2512 -0.1158 0.5071
[torch.FloatTensor of size 4x4]
(
0.5382
2.1519
0.2644
0.5071
[torch.FloatTensor of size 4]
,
2
2
1
3
[torch.LongTensor of size 4]
)