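x is the input tensor.
dim specifies the dimension along which the maximum is taken.
max takes the maximum value; when dim is given it returns both the values and their indices, so [0] selects the values.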
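To make the roles of dim and keepdim concrete, here is a minimal sketch on a small hand-written tensor (the values are made up for illustration and are not from the run shown in this post). It shows that x.max(dim=1, keepdim=True) returns a (values, indices) pair and that keepdim=True leaves dim 1 in place with size 1, which is what lets the result broadcast against x in a comparison.

```python
import torch

# Small deterministic tensor with shape [1, 3, 2, 2] (batch, channel, height, width).
x = torch.tensor([[[[1.0, 5.0], [2.0, 0.0]],
                   [[4.0, 3.0], [2.5, -1.0]],
                   [[0.0, 6.0], [1.0, -2.0]]]])

# With dim given, max returns a (values, indices) pair.
values, indices = x.max(dim=1, keepdim=True)

print(values.shape)  # dim 1 is kept with size 1 because keepdim=True
print(values)        # per-position maximum across the 3 channels
print(indices)       # channel index where each maximum was found
```

Indexing the result with [0] is equivalent to taking values here.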
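import torch

if __name__ == '__main__':
    # Random input of shape [1, 3, 4, 4]; .cuda() requires a CUDA GPU (drop it to run on the CPU).
    x = torch.randn([1, 3, 4, 4]).cuda()
    # 1 where x equals the maximum along dim 1 (the channel dimension), 0 elsewhere.
    mask = (x == x.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
    # Multiplying by the mask keeps the maxima and zeroes every other position.
    result = torch.mul(mask, x)
    print(x)
    print(mask)
    print(result)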
Output:
tensor([[[[-0.8807, 0.1029, 0.0184, 1.2695],
[-0.0934, 1.0650, -0.2927, 0.0049],
[ 0.2338, -1.8663, 1.2763, 0.7248],
[-1.5138, 0.6834, 0.1463, 0.0650]],
[[ 0.5020, 1.6078, -0.0104, 1.2042],
[ 1.8859, -0.4682, -0.1177, 0.5197],
[ 1.7649, 0.4585, 0.6002, 0.3350],
[-1.1384, -0.0325, 0.8490, 0.6080]],
[[-0.5618, 0.5388, -0.0572, -0.7240],
[-0.3458, 1.3494, -0.0603, -1.1562],
[-0.3652, 1.1885, 1.6293, 0.4134],
[ 1.3009, 1.2027, -0.8711, 1.3321]]]], device='cuda:0')
tensor([[[[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 1],
[0, 0, 0, 0]],
[[1, 1, 0, 0],
[1, 0, 0, 1],
[1, 0, 0, 0],
[0, 0, 1, 0]],
[[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 1, 1, 0],
[1, 1, 0, 1]]]], device='cuda:0', dtype=torch.int32)
tensor([[[[-0.0000, 0.0000, 0.0184, 1.2695],
[-0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, 0.7248],
[-0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.5020, 1.6078, -0.0000, 0.0000],
[ 1.8859, -0.0000, -0.0000, 0.5197],
[ 1.7649, 0.0000, 0.0000, 0.0000],
[-0.0000, -0.0000, 0.8490, 0.0000]],
[[-0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 1.3494, -0.0603, -0.0000],
[-0.0000, 1.1885, 1.6293, 0.0000],
[ 1.3009, 1.2027, -0.0000, 1.3321]]]], device='cuda:0')
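One caveat of the equality mask: if several entries along dim tie for the maximum, all of them are marked. The sketch below is an alternative, not the method from the referenced post. It builds a one-hot mask from argmax (or argmin, for the minimum case), so exactly one position per slice is kept. The function name keep_extreme and its largest parameter are hypothetical names chosen for illustration, and a floating-point input is assumed.

```python
import torch

def keep_extreme(x: torch.Tensor, dim: int, largest: bool = True) -> torch.Tensor:
    """Keep the max (or min) along `dim` and zero every other position.

    Hypothetical helper for illustration; assumes a floating-point tensor.
    """
    if largest:
        idx = x.argmax(dim=dim, keepdim=True)
    else:
        idx = x.argmin(dim=dim, keepdim=True)
    # One-hot mask: exactly one 1.0 per slice along `dim`, even when values tie.
    onehot = torch.zeros_like(x).scatter_(dim, idx, 1.0)
    # torch.where writes a plain 0.0 instead of multiplying, so no -0.0 appears.
    return torch.where(onehot.bool(), x, torch.zeros_like(x))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 3, 4, 4, device=device)
print(keep_extreme(x, dim=1))                 # keep the channel-wise maximum
print(keep_extreme(x, dim=1, largest=False))  # keep the channel-wise minimum
```

Compared with multiplying by an int32 mask, torch.where also avoids propagating NaN or inf from positions that are supposed to be zeroed.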
Reference: "PyTorch: keep only the maximum or minimum values of a tensor and set all other positions to zero", bxdzyhx's blog, CSDN: https://blog.csdn.net/bxdzyhx/article/details/120252197?ops_request_misc=&request_id=&biz_id=102&utm_term=torch%20%E7%89%B9%E5%AE%9A%E7%BB%B4%E5%BA%A6%E6%9C%80%E5%A4%A7%E5%80%BC%E5%8F%961%E5%85%B6%E4%BB%96%E7%BD%AE%E9%9B%B6&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-120252197.nonecase&spm=1018.2226.3001.4187