对Pytorch中forward的理解

2023/03/02 Pytorch 科研 共 766 字,约 3 分钟

Pytorch中的forward理解

https://zhuanlan.zhihu.com/p/357021687

在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        # ......

    def forward(self, x):
        # ......
        return x


data = ......  # 输入数据

# 实例化一个对象
model = Module()

# 前向传播
model(data)

# 而不是使用下面的
# model.forward(data)  

实际上model(data)是等价于model.forward(data)的,这是因为torch.nn.Module类中使用了__call__函数,这个函数将类的实例对象变为可调用对象,例如:

class Student:
    def __call__(self, param):
        print('I can called like a function')
        print('传入参数的类型是:{}   值为: {}'.format(type(param), param))

        res = self.forward(param)
        return res

    def forward(self, input_):
        print('forward 函数被调用了')

        print('in  forward, 传入参数类型是:{}  值为: {}'.format(type(input_), input_))
        return input_


a = Student()

input_param = a('data')
print("对象a传入的参数是:", input_param)

__call__函数中调用了forward函数

文档信息

Search

    Table of Contents