torch.argmax(tensor, dim=0)

텐서에서 최댓값이 위치한 인덱스를 반환한다. `dim`을 지정하면 해당 차원을 따라 최댓값의 인덱스를 구하고, `dim`을 생략하면 텐서를 1차원으로 펼친 상태에서의 인덱스를 반환한다.

# Sample a 4x4 matrix from the standard normal distribution.
a = torch.randn((4, 4))
print(a)
# Reduce along dim 0: for every column, the row index holding its maximum.
output = a.argmax(dim=0)
print(output)

#### output ####
'''
tensor([[-1.6370,  1.4183, -0.1544, -0.7080],
        [ 1.6758, -0.1570,  0.5589, -1.5919],
        [-0.3721,  1.6971,  0.1501, -0.0780],
        [ 0.6539,  0.1301,  0.6457,  0.8172]])
tensor([1, 2, 3, 3])
'''
# Sample a fresh 4x4 matrix from the standard normal distribution.
a = torch.randn((4, 4))
print(a)
# No dim given: the result is a 0-d tensor holding the position of the
# maximum in the flattened (row-major) view of `a`.
output = a.argmax()
print(output)

#### output ####
'''
tensor([[-0.5014, -0.1785,  0.2534,  0.7167],
        [-0.7887,  1.0920,  0.5385, -1.1797],
        [-1.0129,  0.2337,  0.5757,  0.9139],
        [ 1.4672, -1.0605, -0.1456, -0.9329]])
tensor(12)
'''