PyTorch 中的 mul

PyTorch 中的 mul
最新回答
快乐暗恋我

2021-04-22 03:18:28

PyTorch中的mul()函数用于执行元素级别的乘法运算,支持张量与张量、标量与张量、张量与标量以及标量与标量之间的操作,并可通过in-place操作直接修改输入张量。

核心功能与参数说明
  • 功能:对输入的两个参数(input和other)进行逐元素乘法运算,结果张量的每个元素为input和other对应位置元素的乘积。
  • 参数类型

    input:可为PyTorch张量或标量(int、float、complex、bool类型)。

    other:可为PyTorch张量或标量(类型同input)。

    out(可选):指定输出张量的存储位置,避免额外内存分配。

  • 形状兼容性

    若input和other均为张量,其形状需满足广播规则(如[3]与[2,3]可广播为[2,3])。

    若其中一者为标量,则标量会与另一张量的每个元素相乘。

操作类型与示例
  1. 张量与张量相乘输入张量形状需兼容,结果为逐元素乘积。

    import torchtensor1 = torch.tensor([9, 7, 6])tensor2 = torch.tensor([[4, -4, 3], [-2, 5, -5]])result = torch.mul(input=tensor1, other=tensor2)# 输出:tensor([[36, -28, 18], [-18, 35, -30]])
  2. 标量与张量相乘标量与张量每个元素相乘,结果形状与原张量一致。

    result = torch.mul(input=9, other=tensor2)# 输出:tensor([[36, -36, 27], [-18, 45, -45]])result = torch.mul(input=tensor1, other=4)# 输出:tensor([36, 28, 24])
  3. 标量与标量相乘直接返回两标量的乘积,结果为标量张量。

    result = torch.mul(input=9, other=4)# 输出:tensor(36)
  4. 支持的数据类型除整数外,mul()还支持浮点数、复数和布尔类型:

    浮点数示例:torch.mul(input=1.5, other=torch.tensor([2.0, 3.0]))

    复数示例:torch.mul(input=1+2j, other=torch.tensor([1j, 2j]))

    布尔类型示例:torch.mul(input=True, other=torch.tensor([True, False]))

in-place操作

通过添加下划线(_)直接修改输入张量,减少内存开销:

tensor1.mul_(other=tensor2) # tensor1被修改为乘积结果multiply()函数

multiply()是mul()的别名,功能完全相同,可互换使用:

result = torch.multiply(input=tensor1, other=4) # 等价于torch.mul()注意事项
  • 形状匹配:张量相乘时需确保形状兼容,否则会报错(如[3]与[4]无法广播)。
  • 数据类型一致性:若输入为不同类型(如int与float),结果会自动向上转换(如int * float → float)。
  • 性能优化:对大规模张量操作时,in-place操作(mul_)可提升效率,但会覆盖原数据,需谨慎使用。
总结

mul()是PyTorch中基础的元素级乘法函数,支持多种数据类型和形状组合,并通过in-place操作优化性能。其别名multiply()提供更直观的命名选择。使用时需确保参数形状兼容,并根据需求选择是否直接修改输入张量。