Pytorch中的矩阵乘法

一些矩阵乘法

参考: https://blog.csdn.net/qq_42388742/article/details/120474434

torch.bmm

两个batch做矩阵乘法,是两个三维张量相乘, 两个输入tensor维度是$(b\times n\times m)$和$(b\times m\times p)$, 第一维b代表batch size,输出为$(b\times n \times p)$。

torch.mm

mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是$(n\times m)$和$(m\times p)$。

文档信息

Search

    Table of Contents