3/02/2022

flatten, unflatten in PyTorch (torch tensor reshape)

flatten -> unflatten: turn a (batch, ch, row, col) tensor into (batch, ch*row*col), then restore the original shape.


..

import torch

# batch, ch, row, col (torch.empty leaves memory uninitialized; only the shape matters here)
x = torch.empty(10, 1, 256, 256)
print('input shape: ', x.shape)

# flatten: batch, ch*row*col (start_dim=1 keeps the batch dimension)
flat = torch.flatten(x, start_dim=1)
# flat = x.reshape(-1, 1*256*256)  # equivalent alternative
print('reshape or flatten: ', flat.shape)

# unflatten: back to batch, ch, row, col
restored = flat.reshape(-1, 1, 256, 256)
print('unflatten: ', restored.shape)

..


input shape:  torch.Size([10, 1, 256, 256])
reshape or flatten:  torch.Size([10, 65536])
unflatten:  torch.Size([10, 1, 256, 256])
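The shapes above only show that the round trip restores the layout. A minimal sketch to check that the values also come back unchanged, and that `torch.flatten(x, start_dim=1)` matches `x.reshape(batch, -1)`. It uses `torch.randn` instead of `torch.empty`, because `torch.empty` returns uninitialized memory that may contain NaN, which would make an equality check unreliable. The `Tensor.unflatten(dim, sizes)` line assumes a reasonably recent PyTorch (roughly 1.13 or later); older versions can stick with `reshape`.

```python
import torch

# Same layout as above: batch, ch, row, col (random values so equality is meaningful)
x = torch.randn(10, 1, 256, 256)

# flatten(start_dim=1) keeps the batch dim and merges ch*row*col
flat = torch.flatten(x, start_dim=1)
flat_reshape = x.reshape(x.size(0), -1)
print(torch.equal(flat, flat_reshape))  # True

# Undo it: reshape back, or split dim 1 into (ch, row, col) with Tensor.unflatten
restored = flat.reshape(-1, 1, 256, 256)
restored_unflat = flat.unflatten(1, (1, 256, 256))
print(torch.equal(restored, x), torch.equal(restored_unflat, x))  # True True
```

`reshape(-1, 1, 256, 256)` infers the batch size, while `unflatten` only touches the chosen dimension, so it does not need to know the batch size or the other dimensions.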

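Inside a model, the same flatten -> unflatten steps can be written as layers with `nn.Flatten` and `nn.Unflatten`. A sketch under the assumption of a hypothetical toy model (not from this post): the 32-unit linear bottleneck is an arbitrary choice, only there to show the shapes around the two layers.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: flatten, pass through a small linear bottleneck,
# then restore the (ch, row, col) layout with nn.Unflatten.
model = nn.Sequential(
    nn.Flatten(start_dim=1),          # (N, 1, 256, 256) -> (N, 65536)
    nn.Linear(1 * 256 * 256, 32),     # (N, 65536) -> (N, 32)
    nn.Linear(32, 1 * 256 * 256),     # (N, 32) -> (N, 65536)
    nn.Unflatten(1, (1, 256, 256)),   # (N, 65536) -> (N, 1, 256, 256)
)

x = torch.randn(10, 1, 256, 256)
out = model(x)
print(out.shape)  # torch.Size([10, 1, 256, 256])
```

`nn.Flatten` defaults to `start_dim=1`, so it keeps the batch dimension just like `torch.flatten(x, start_dim=1)`.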

Thank you.
πŸ™‡πŸ»‍♂️