目录
PyTorch 是由 Facebook(现 Meta)开发的开源深度学习框架,以动态计算图、简洁 API 和强大的 Python 集成著称,是科研和工业界实现深度学习模型的主流工具之一。
一、核心认知(1 天):先搞懂 “为什么学” 和 “基础概念”
1. 核心优势(明确学习价值)
- 动态图优先:代码编写与 Python 逻辑一致(如写
if
/for
直接控制流程),调试直观,适合快速迭代。 - API 简洁:贴近 Python 语法,入门门槛低于 TensorFlow,新手易上手。
- 生态丰富:配套工具链完善(数据处理
TorchVision
、自然语言Hugging Face Transformers
等)。
2. 必学基础概念(不纠结细节,先建立认知)
概念 | 作用 | 类比 |
---|---|---|
Tensor(张量) | 核心数据结构,可理解为 “高维数组”(0 维 = 标量、1 维 = 向量、2 维 = 矩阵) | 相当于 NumPy 数组,但支持 GPU 加速 |
Autograd(自动求导) | 自动计算张量的梯度,是反向传播的核心 | 不用手动写导数公式,框架帮你算 |
Module(模块) | 封装神经网络层的类,方便构建 / 管理模型 | 像 “乐高积木”,每个 Module 是一块积木,拼起来就是模型 |
GPU 加速 | 用 cuda() 把 Tensor / 模型移到 GPU,大幅提升计算速度 |
用 “超级计算机” 代替 “普通电脑” 跑任务 |
二、基础实操(3 天):动手写代码,掌握核心工具
目标:能用 PyTorch 实现 “数据加载→模型定义→训练→预测” 的完整流程,推荐用 Google Colab(免装环境,直接用 GPU)或本地 Jupyter Notebook。
Day 1:Tensor 与 Autograd(核心数据操作)
-
Tensor 基本操作(类比 NumPy,重点记差异)
- 创建:
torch.tensor()
(手动传值)、torch.zeros()
/torch.ones()
(全 0 / 全 1)、torch.randn()
(正态分布)。 - 设备切换:
x = x.cuda()
(移到 GPU)、x = x.cpu()
(移回 CPU)。 - 常用运算:
x + y
(元素加)、x @ y
(矩阵乘)、x.view()
(改变形状,类似reshape
)。
- 创建:
-
Autograd 核心逻辑
- 开启求导:创建 Tensor 时加
requires_grad=True
(如x = torch.tensor([2.0], requires_grad=True)
)。 - 计算梯度:先算损失(如
y = x**2
),再执行y.backward()
,梯度会存在x.grad
中。 - 禁用求导:用
with torch.no_grad():
包裹(测试 / 推理时用,避免浪费资源)。
- 开启求导:创建 Tensor 时加
练习:写一段代码,计算y = 3x² + 2x + 1
在x=5
处的导数(答案应为 32)。
Day 2:数据加载与模型定义
-
数据加载(Torch.utils.data)
- 核心类:
Dataset
(自定义数据集,需实现__getitem__
方法)、DataLoader
(批量加载,支持打乱、多线程)。 - 示例:用
TorchVision.datasets.MNIST
加载手写数字数据集,用DataLoader
按批次读取。
- 核心类:
-
模型定义(nn.Module)
- 步骤:1. 继承
nn.Module
;2. 在__init__
中定义层(如nn.Linear
全连接层、nn.Conv2d
卷积层);3. 实现forward
方法(定义前向传播流程)。 - 示例:定义一个简单的全连接网络(输入 784→隐藏层 256→输出 10,用于 MNIST 分类)。
- 步骤:1. 继承
Day 3:训练与优化(核心流程)
-
核心组件
- 损失函数:分类用
nn.CrossEntropyLoss()
,回归用nn.MSELoss()
。 - 优化器:
torch.optim.SGD()
(随机梯度下降)、torch.optim.Adam()
(常用,收敛快)。
- 损失函数:分类用
-
训练循环模板(必背)python运行12345678910111213141516171819202122232425262728<span class="token comment"># 1. 初始化模型、损失函数、优化器</span>model <span class="token operator">=</span> MyModel<span class="token punctuation">(</span><span class="token punctuation">)</span>criterion <span class="token operator">=</span> nn<span class="token punctuation">.</span>CrossEntropyLoss<span class="token punctuation">(</span><span class="token punctuation">)</span>optimizer <span class="token operator">=</span> torch<span class="token punctuation">.</span>optim<span class="token punctuation">.</span>Adam<span class="token punctuation">(</span>model<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> lr<span class="token operator">=</span><span class="token number">1e-3</span><span class="token punctuation">)</span><span class="token comment"># 2. 训练循环(多轮 epoch)</span><span class="token keyword">for</span> epoch <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># 训练5轮</span>model<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 切换到训练模式(启用 dropout/batchnorm)</span><span class="token keyword">for</span> images<span class="token punctuation">,</span> labels <span class="token keyword">in</span> train_loader<span class="token punctuation">:</span> <span class="token comment"># 按批次取数据</span><span class="token comment"># 前向传播</span>outputs <span class="token operator">=</span> model<span class="token punctuation">(</span>images<span class="token punctuation">)</span>loss <span class="token operator">=</span> criterion<span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> labels<span class="token punctuation">)</span><span class="token comment"># 反向传播 + 优化</span>optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 清空上一轮梯度</span>loss<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 计算梯度</span>optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 更新参数</span><span class="token comment"># 测试(可选)</span>model<span class="token punctuation">.</span><span class="token builtin">eval</span><span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 切换到评估模式(禁用 dropout/batchnorm)</span><span class="token keyword">with</span> torch<span class="token punctuation">.</span>no_grad<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>total_correct <span class="token operator">=</span> <span class="token number">0</span><span class="token keyword">for</span> images<span class="token punctuation">,</span> labels <span class="token keyword">in</span> test_loader<span class="token punctuation">:</span>outputs <span class="token operator">=</span> model<span class="token punctuation">(</span>images<span class="token punctuation">)</span>_<span class="token punctuation">,</span> preds <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">max</span><span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>total_correct <span class="token operator">+=</span> <span class="token punctuation">(</span>preds <span class="token operator">==</span> labels<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"Epoch </span><span class="token interpolation"><span class="token punctuation">{</span>epoch<span class="token operator">+</span><span class="token number">1</span><span class="token punctuation">}</span></span><span class="token string">, 测试准确率: </span><span class="token interpolation"><span class="token punctuation">{</span>total_correct<span class="token operator">/</span><span class="token builtin">len</span><span class="token punctuation">(</span>test_loader<span class="token punctuation">.</span>dataset<span class="token punctuation">)</span><span class="token punctuation">:</span><span class="token format-spec">.2f</span><span class="token punctuation">}</span></span><span class="token string">"</span></span><span class="token punctuation">)</span>
三、实战巩固(2 天):用小项目落地,强化记忆
选择 1-2 个简单项目,重点是 “复现流程”,不追求模型复杂度:
项目 1:MNIST 手写数字分类(入门首选)
- 目标:用全连接网络或简单卷积网络(
nn.Conv2d
+nn.MaxPool2d
)实现,目标测试准确率≥97%。 - 重点:熟悉数据加载、模型定义、训练循环的完整链路。
项目 2:简单回归任务(如房价预测简化版)
- 目标:用
nn.Linear
构建回归模型,预测输入特征对应的连续值(如用随机生成的 “面积、房间数” 预测 “房价”)。 - 重点:理解回归任务与分类任务的差异(损失函数、评价指标)。
四、查漏补缺(1 天):解决高频问题,补充关键细节
-
高频报错与解决方案
RuntimeError: Expected object of device type cuda but got device type cpu
:Tensor 和模型不在同一设备(统一用cuda()
或cpu()
)。AttributeError: 'NoneType' object has no attribute 'grad'
:没加requires_grad=True
,或没执行backward()
。- 梯度爆炸 / 消失:初期可忽略,先保证流程跑通,后续再学
Gradient Clipping
等技巧。
-
关键细节补充
model.train()
vsmodel.eval()
:训练时用前者(启用 BatchNorm 更新、Dropout 随机失活),测试时用后者(固定 BatchNorm 参数、关闭 Dropout)。- 模型保存与加载:
torch.save(model.state_dict(), "model.pth")
(保存参数)、model.load_state_dict(torch.load("model.pth"))
(加载参数)。
五、学习资源推荐(高效避坑)
- 官方教程:PyTorch Quickstart(最权威,代码可直接复制运行)。
- 视频教程:B 站 “跟李沐学 AI” 的《PyTorch 实战》(前 5 集足够,讲得细且贴近实战)。
- 工具:遇到 API 疑问直接查 PyTorch 官方文档(搜索关键词如 “torch.nn.Linear”)。
关键提醒
- 一周内目标是 “能用 PyTorch 跑通基础流程”,而非 “精通所有细节”(如分布式训练、自定义反向传播可后续学)。
- 代码一定要自己敲,不要只看 —— 哪怕复制后改参数(如把学习率从 1e-3 改成 1e-4),也能加深理解。
- 遇到报错先自己搜(关键词 +“PyTorch”),大部分新手问题都有成熟解决方案。