flatten -> unflatten
..
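import torch

# batch, ch, row, col (torch.empty leaves the values uninitialized; only the shape matters here)
input_size = torch.empty(10, 1, 256, 256)
print('input shape: ', input_size.shape)
# flatten: batch, ch*row*col (start_dim=1 keeps the batch dimension)
reshape = torch.flatten(input_size, start_dim=1)
# reshape = input_size.reshape(-1, 1*256*256)
print('reshape or flatten: ', reshape.shape)
# unflatten: batch, ch, row, col (-1 lets PyTorch infer the batch size)
reshape = reshape.reshape(-1, 1, 256, 256)
# reshape = reshape.unflatten(1, (1, 256, 256))
print('unflatten: ', reshape.shape)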
..
input shape: torch.Size([10, 1, 256, 256])
reshape or flatten: torch.Size([10, 65536])
unflatten: torch.Size([10, 1, 256, 256])
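If the same round trip is needed inside a model, PyTorch also has module versions, `nn.Flatten` and `nn.Unflatten`, plus the `Tensor.unflatten` method. The sketch below is a minimal example, assuming PyTorch 1.13 or newer; the names `x`, `flat`, `restored` and `restored2` are illustrative only. It uses `torch.randn` instead of `torch.empty` so the values can be compared after the round trip.

```python
import torch
import torch.nn as nn

# Same shape as above: batch, ch, row, col (random values so equality can be checked)
x = torch.randn(10, 1, 256, 256)

# Module form, handy inside nn.Sequential
flatten = nn.Flatten(start_dim=1)           # (10, 1, 256, 256) -> (10, 65536)
unflatten = nn.Unflatten(1, (1, 256, 256))  # (10, 65536) -> (10, 1, 256, 256)

flat = flatten(x)
restored = unflatten(flat)
print('flatten: ', flat.shape)
print('unflatten: ', restored.shape)

# Tensor-method form: only dim 1 is split, the batch dim is never touched
restored2 = flat.unflatten(1, (1, 256, 256))
print('Tensor.unflatten: ', restored2.shape)

# Flatten/unflatten only change the view of the data, not the values
print('round trip equal: ', torch.equal(restored, x) and torch.equal(restored2, x))
```

Unlike `reshape(-1, 1, 256, 256)`, `unflatten` names the dimension to split, so the batch size never has to be inferred with `-1`.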
Thank you.