2021-04-22 03:18:28
PyTorch中的mul()函数用于执行元素级别的乘法运算,支持张量与张量、标量与张量、张量与标量以及标量与标量之间的操作,并可通过in-place操作直接修改输入张量。
核心功能与参数说明input:可为PyTorch张量或标量(int、float、complex、bool类型)。
other:可为PyTorch张量或标量(类型同input)。
out(可选):指定输出张量的存储位置,避免额外内存分配。
若input和other均为张量,其形状需满足广播规则(如[3]与[2,3]可广播为[2,3])。
若其中一者为标量,则标量会与另一张量的每个元素相乘。
张量与张量相乘输入张量形状需兼容,结果为逐元素乘积。
import torchtensor1 = torch.tensor([9, 7, 6])tensor2 = torch.tensor([[4, -4, 3], [-2, 5, -5]])result = torch.mul(input=tensor1, other=tensor2)# 输出:tensor([[36, -28, 18], [-18, 35, -30]])标量与张量相乘标量与张量每个元素相乘,结果形状与原张量一致。
result = torch.mul(input=9, other=tensor2)# 输出:tensor([[36, -36, 27], [-18, 45, -45]])result = torch.mul(input=tensor1, other=4)# 输出:tensor([36, 28, 24])标量与标量相乘直接返回两标量的乘积,结果为标量张量。
result = torch.mul(input=9, other=4)# 输出:tensor(36)支持的数据类型除整数外,mul()还支持浮点数、复数和布尔类型:
浮点数示例:torch.mul(input=1.5, other=torch.tensor([2.0, 3.0]))
复数示例:torch.mul(input=1+2j, other=torch.tensor([1j, 2j]))
布尔类型示例:torch.mul(input=True, other=torch.tensor([True, False]))
通过添加下划线(_)直接修改输入张量,减少内存开销:
tensor1.mul_(other=tensor2) # tensor1被修改为乘积结果multiply()函数multiply()是mul()的别名,功能完全相同,可互换使用:
result = torch.multiply(input=tensor1, other=4) # 等价于torch.mul()注意事项mul()是PyTorch中基础的元素级乘法函数,支持多种数据类型和形状组合,并通过in-place操作优化性能。其别名multiply()提供更直观的命名选择。使用时需确保参数形状兼容,并根据需求选择是否直接修改输入张量。