torch.argmax를 사용하면 텐서에서 최댓값의 인덱스를 구할 수 있다. dim을 지정하면 그 차원을 따라 각각의 최댓값 위치를, 지정하지 않으면 텐서를 1차원으로 펼쳤을 때의 인덱스 하나를 반환한다.
# Column-wise argmax: passing dim=0 reduces over the row axis, so the result
# holds, for each of the 4 columns, the row index of that column's largest value.
a = torch.randn(4, 4)
print(a)
output = a.argmax(dim=0)
print(output)
# Sample output from one run. torch.randn draws new values each run, so a
# fresh run prints different numbers unless the RNG is seeded somewhere.
# tensor([[-1.6370,  1.4183, -0.1544, -0.7080],
#         [ 1.6758, -0.1570,  0.5589, -1.5919],
#         [-0.3721,  1.6971,  0.1501, -0.0780],
#         [ 0.6539,  0.1301,  0.6457,  0.8172]])
# tensor([1, 2, 3, 3])
# Global argmax: with no dim argument, torch.argmax looks at the tensor as if
# it were flattened and returns a single 0-d index into that flat view.
a = torch.randn(4, 4)
print(a)
output = a.argmax()
print(output)
# Sample output from one run (values change between runs unless seeded).
# tensor([[-0.5014, -0.1785,  0.2534,  0.7167],
#         [-0.7887,  1.0920,  0.5385, -1.1797],
#         [-1.0129,  0.2337,  0.5757,  0.9139],
#         [ 1.4672, -1.0605, -0.1456, -0.9329]])
# tensor(12)
# Flat index 12 in a 4x4 tensor is row 12 // 4 = 3, column 12 % 4 = 0,
# i.e. the value 1.4672.