一、什么是Pytorch?

机器学习框架。

主要功能:自动微分计算、GPUs运算。

二、模型的评估

训练、验证、测试

三、加载数据

1.torch.Dataset

读取数据。

2.torch.Dataloader

加载数据batch。

四、自定义数据加载类

就是Dataloader。

五、Tensor

六、torch.nn

layer = nn.Linear # 线性层

.weight .bias

七、torch.optim

SGD。

梯度清零、反向传播、梯度更新。

八、训练

net.train() 训练

net.eval() + with torch.no_grad() 验证、测试

torch.save(net.state_dict(), path) 保存参数

x = torch.load(path)

net.load_state_dict(x) 加载参数

九、Colab

免费GPU。

code cell、text cell。

!用于指定指令。

!nvidia-smi