你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

torch.max讲解-- 指定维度下,数值最大位置设为1其他为0

2021-11-20 12:38:46

x是输入张量
dim指定维度
max表示取最大值


import torch

if __name__ == '__main__':
    
    x = torch.randn([1, 3, 4, 4]).cuda()

    mask = (x == x.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
    result = torch.mul(mask, x)

    print(x)
    print(mask)
    print(result)

效果:

tensor([[[[-0.8807,  0.1029,  0.0184,  1.2695],
          [-0.0934,  1.0650, -0.2927,  0.0049],
          [ 0.2338, -1.8663,  1.2763,  0.7248],
          [-1.5138,  0.6834,  0.1463,  0.0650]],

         [[ 0.5020,  1.6078, -0.0104,  1.2042],
          [ 1.8859, -0.4682, -0.1177,  0.5197],
          [ 1.7649,  0.4585,  0.6002,  0.3350],
          [-1.1384, -0.0325,  0.8490,  0.6080]],

         [[-0.5618,  0.5388, -0.0572, -0.7240],
          [-0.3458,  1.3494, -0.0603, -1.1562],
          [-0.3652,  1.1885,  1.6293,  0.4134],
          [ 1.3009,  1.2027, -0.8711,  1.3321]]]], device='cuda:0')
tensor([[[[0, 0, 1, 1],
          [0, 0, 0, 0],
          [0, 0, 0, 1],
          [0, 0, 0, 0]],

         [[1, 1, 0, 0],
          [1, 0, 0, 1],
          [1, 0, 0, 0],
          [0, 0, 1, 0]],

         [[0, 0, 0, 0],
          [0, 1, 1, 0],
          [0, 1, 1, 0],
          [1, 1, 0, 1]]]], device='cuda:0', dtype=torch.int32)
tensor([[[[-0.0000,  0.0000,  0.0184,  1.2695],
          [-0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.7248],
          [-0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5020,  1.6078, -0.0000,  0.0000],
          [ 1.8859, -0.0000, -0.0000,  0.5197],
          [ 1.7649,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.8490,  0.0000]],

         [[-0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  1.3494, -0.0603, -0.0000],
          [-0.0000,  1.1885,  1.6293,  0.0000],
          [ 1.3009,  1.2027, -0.0000,  1.3321]]]], device='cuda:0')

Process finished with exit code 0

参考链接:pytorch 只保留tensor的最大值或最小值,其他位置置零_bxdzyhx的博客-CSDN博客https://blog.csdn.net/bxdzyhx/article/details/120252197?ops_request_misc=&request_id=&biz_id=102&utm_term=torch%20%E7%89%B9%E5%AE%9A%E7%BB%B4%E5%BA%A6%E6%9C%80%E5%A4%A7%E5%80%BC%E5%8F%961%E5%85%B6%E4%BB%96%E7%BD%AE%E9%9B%B6&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-0-120252197.nonecase&spm=1018.2226.3001.4187

torch.max zhon