How can I access layers in a pytorch module by index?(如何通过索引访问 pytorch 模块中的层?)
问题描述
我正在尝试编写一个具有多个层的 pytorch 模块.由于我需要中间输出,我不能像往常一样将它们全部放入 Sequantial 中.另一方面,由于有很多层,我的想法是将层放在一个列表中并在循环中通过索引访问它们.下面描述我想要实现的目标:
I am trying to write a pytorch module with multiple layers. Since I need the intermediate outputs I cannot put them all in a Sequantial as usual. On the other hand, since there are many layers, what I have in mind is to put the layers in a list and access them by index in a loop. Below describe what I am trying to achieve:
import torch
import torch.nn as nn
import torch.optim as optim
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layer_list = []
self.layer_list.append(nn.Linear(2,3))
self.layer_list.append(nn.Linear(3,4))
self.layer_list.append(nn.Linear(4,5))
def forward(self, x):
res_list = [x]
for i in range(len(self.layer_list)):
res_list.append(self.layer_list[i](res_list[-1]))
return res_list
model = MyModel()
x = torch.randn(4,2)
y = model(x)
print(y)
optimizer = optim.Adam(model.parameters())
forward 方法工作正常,但是当我想设置优化器时,程序说
The forward method works fine, but when I want to set an optimizer the program says
ValueError: optimizer got an empty parameter list
列表中的图层似乎没有在这里注册.我能做什么?
It appears that the layers in the list are not registered here. What can I do?
推荐答案
如果你把你的图层放在 python 列表中,pytorch 不会正确注册它们.您必须使用 ModuleList (https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).
If you put your layers in a python list, pytorch does not register them correctly. You have to do so using ModuleList (https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).
ModuleList 可以像常规 Python 列表一样被索引,但它包含的模块已正确注册,并且所有模块方法都可以看到.
ModuleList can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods.
您的代码应该类似于:
import torch
import torch.nn as nn
import torch.optim as optim
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layer_list = nn.ModuleList() # << the only changed line! <<
self.layer_list.append(nn.Linear(2,3))
self.layer_list.append(nn.Linear(3,4))
self.layer_list.append(nn.Linear(4,5))
def forward(self, x):
res_list = [x]
for i in range(len(self.layer_list)):
res_list.append(self.layer_list[i](res_list[-1]))
return res_list
通过使用 ModuleList,您可以确保所有层都在计算图中注册.
By using ModuleList you make sure all layers are registered in the computational graph.
还有一个 ModuleDict 如果你想按名称索引你的图层,你可以使用它.您可以在此处查看 pytorch 的容器:https://pytorch.org/docs/master/nn.html#containers
There is also a ModuleDict that you can use if you want to index your layers by name. You can check pytorch's containers here: https://pytorch.org/docs/master/nn.html#containers
这篇关于如何通过索引访问 pytorch 模块中的层?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:如何通过索引访问 pytorch 模块中的层?
基础教程推荐
- Python,确定字符串是否应转换为 Int 或 Float 2022-01-01
- 在 Python 中将货币解析为数字 2022-01-01
- 在 Django Admin 中使用内联 OneToOneField 2022-01-01
- matplotlib 设置 yaxis 标签大小 2022-01-01
- 对多索引数据帧的列进行排序 2022-01-01
- kivy 应用程序中的一个简单网页作为小部件 2022-01-01
- Python 中是否有任何支持将长字符串转储为块文字或折叠块的 yaml 库? 2022-01-01
- 比较两个文本文件以找出差异并将它们输出到新的文本文件 2022-01-01
- Kivy 使用 opencv.调整图像大小 2022-01-01
- 究竟什么是“容器"?在蟒蛇?(以及所有的 python 容器类型是什么?) 2022-01-01
