Pytorch中的backward()函数使用技巧

loss.backward()

Posted by Ellie on October 31, 2024

条件

为了确保梯度能够传递到自定义 encoder,需要满足以下条件:

  1. 模型集成到计算图中
    你的自定义 encoder 必须用 PyTorch 定义,并且其输出被纳入到计算损失的过程中。例如,假设你替换了文本编码器(text encoder),可以这样构造架构:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    # CLIP image encoder
    image_features = clip_model.get_image_features(inputs["pixel_values"].to(device))
    
    # Custom text encoder
    text_features = my_text_encoder(inputs['input_ids'].to(device))
    
    # 对比学习损失
    logits = image_features @ text_features.T  # 点积相似度
    labels = torch.arange(len(logits)).to(device)
    loss = torch.nn.CrossEntropyLoss()(logits, labels)
    
  2. 自定义 encoder 的参数可优化
    自定义 encoder 的参数必须在模型参数中注册,比如通过 torch.nn.Module 子类实现,确保其参数包含在优化器中。
    1
    2
    3
    4
    
    optimizer = torch.optim.Adam(
        list(clip_model.parameters()) + list(my_text_encoder.parameters()),
        lr=1e-4
    )
    
  3. 损失计算包括自定义 encoder 的输出
    如果对比学习的损失只依赖 CLIP 的编码器输出,而未包含自定义 encoder 的输出,那么梯度不会传递到自定义 encoder。

loss.backward() 的作用

  1. 自动求导机制
    PyTorch 的自动求导机制会跟踪所有张量操作,构建动态计算图。只要自定义 encoder 的输出参与了损失计算,它的参数梯度会被自动计算。

  2. 共享或分离计算图

    • 如果你的自定义 encoder 和 CLIP 的 encoder 是独立的模块,它们的计算图会被同时追踪。
    • 注意:如果在编码器之间使用了 detach(),梯度不会传递给被 detach 的部分。例如:
      1
      2
      
      # 这样会阻断梯度流向 CLIP 的编码器
      image_features = clip_model.get_image_features(inputs["pixel_values"].to(device)).detach()
      

示例代码

下面是一个完整的训练过程代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 假设使用 CLIP 的 image encoder 和自定义 text encoder
image_features = clip_model.get_image_features(inputs["pixel_values"].to(device))
text_features = my_text_encoder(inputs['input_ids'].to(device))

# 对比学习损失
logits_per_image = image_features @ text_features.T
labels = torch.arange(len(logits_per_image)).to(device)
loss = torch.nn.CrossEntropyLoss()(logits_per_image, labels)

# 反向传播
loss.backward()

# 更新参数
optimizer.step()
optimizer.zero_grad()

可能遇到的问题

  1. 梯度未更新
    如果你的自定义 encoder 的输出未参与损失计算,梯度不会更新。这时需要检查自定义 encoder 的输出是否正确连接到损失函数。

  2. 梯度爆炸或消失
    如果你的自定义 encoder 设计较为复杂,可能导致梯度爆炸或消失。可以尝试使用梯度裁剪(gradient clipping)或调整学习率。

  3. 模型参数未注册到优化器中
    如果你的自定义 encoder 的参数未包含在优化器中,训练时这些参数不会更新。检查优化器的参数是否包括自定义 encoder 的参数:

    1
    2
    
    for param_group in optimizer.param_groups:
        print(param_group['params'])