Pytorch中的forward理解
https://zhuanlan.zhihu.com/p/357021687
在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。
class Module(nn.Module):
def __init__(self):
super().__init__()
# ......
def forward(self, x):
# ......
return x
data = ...... # 输入数据
# 实例化一个对象
model = Module()
# 前向传播
model(data)
# 而不是使用下面的
# model.forward(data)
实际上model(data)
是等价于model.forward(data)
的,这是因为torch.nn.Module
类中使用了__call__
函数,这个函数将类的实例对象变为可调用对象,例如:
class Student:
def __call__(self, param):
print('I can called like a function')
print('传入参数的类型是:{} 值为: {}'.format(type(param), param))
res = self.forward(param)
return res
def forward(self, input_):
print('forward 函数被调用了')
print('in forward, 传入参数类型是:{} 值为: {}'.format(type(input_), input_))
return input_
a = Student()
input_param = a('data')
print("对象a传入的参数是:", input_param)
即__call__
函数中调用了forward
函数
文档信息
- 本文作者:焦逸凡
- 本文链接:https://ailovejinx.github.io/2023/03/02/blog-forward/
- 版权声明:自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)