NLP study

[Pytorch] pad_sequence와 pack_padded_sequence 그리고 collate_fn

2로 접어듦 2023. 2. 22. 21:17

1. pad_sequence, pack_padded_sequence 란 무엇인가?

 

딥러닝 모델 학습에는 batch_size 설정이 필수적이다. GPU를 이용하여 computation을 parallelize할 수 있기 때문이다.

자연어처리에서는 input data의 길이가 variable하므로, batch를 설정할 때, 사진과 같이, 5개의 text를 하나의 batch 라고 할 때, 문장의 길이가 다르기 때문에 빈 공간이 생기게 된다(일반적으로 padding 처리를 한다.)

문장의 길이가 다른 것을 same input feature로 만들기 위해서 padding처리를 한 뒤 RNN모델과 같은 DL 모델에 input으로 들어가게 되는데, 이 때 이러한 input batch는 다음과 같은 문제점을 갖는다.

 

Variable length of input의 문제점

  • Padding을 추가한다면 padding이 학습되고 값이 변형되어 Model로 feed 되어, 최종 결과물에 영향을 주게 된다.
  • Padding을 추가하지 않고 paralellize 하게 계산하기 위해서는, input에 대한 previous time step을 저장해야한다(the end of each sequence).

 

pack_padded_sequence는 이러한 문제를 해결하는 데 도움을 줄 수 있다. 임의로 추가된 pad를 무시하고, batch로 형성된 input을 한 차원 낮춰 압축하는 형태로 processing 한다.정리하면 다음과 같은 순서를 거친다.

  1. pad_sequence를 이용해서 input data는 uniform sequence length로 정리된다.
  2. pad_sequence와 original sequence length로 어디에 padding이 더해졌는지를 알 수 있다.
  3. pack_padded_sequence는 pad를 제거한 sequence를 packing한다(차원이 아마 하나 줄어드는 것으로 이해함.)

 

pack_padded_sequence는 다음과 같은 파라미터를 받는다.

  • input(Tensor): pad 처리 된 batch of variable length sequences.
  • lenghts(Tensor or list(int)): batch 형태로 이루어진 input의 각각의 length를 list 형태로 받는다. tensor로 주면, cpu연산을 해야한다.
  • batch_first
  • enforce_sorted: input batch 를 길이 순으로 내림차순 정렬을 하지 않았다면 False 처리해주어야 한다. 디폴트가 True이다.

참고: https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html

 

이렇게만 보면 무슨 말인지 모른다. 예제 코드를 보자.

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# 서로 다른 길이의 sequence를 정의하자. 순서대로 (6, 10) (4, 10) (2, 10)인 tensor가 리스트로 묶여있다.
sequences = [torch.randn(i, 10) for i in [6, 4, 2]]

# 서로 다른 길이의 input을 맞춰주기 위해 padding 처리를 한다.
# 리턴 shape는 torch.Size([3, 6, 10]) 으로, batch * 가장 긴 length * feature 이다.
padded_sequences = pad_sequence(sequences, batch_first=True)

# 서로 다른 길이의 각 sequence의 길이 정보가 필요하다. 이를 리스트 형태로 저장한다.
lengths = [len(seq) for seq in sequences]

# packing 부분. pad sequence와 각 sequence의 길이, 등등을 param으로 받는다.
packed_seq = pack_padded_sequence(padded_sequences, lengths, batch_first=True, enforce_sorted=False)

# LSTM 을 정의할 때, input feature는 위에서 봤던 것처럼 10 이다. hidden 은 임의로 정할 수 있다.
lstm = torch.nn.LSTM(10, 20, batch_first=True)

# Pack 된 input data를 lstm 모델에 input 으로 넣을 수 있다.
packed_output, (h_n, c_n) = lstm(packed_seq)

# 결과 사이즈 설명
# packed_output은 PackedSequence 타입이다.
# h_n은 torch.Size([1, 3, 20]) 형태로, 1 은 pack의 영향으로 생긴 것으로 추측하며(불확실), batch * hidden 은 나머지.
# c_n은 torch.Size([1, 3, 20]), cell_state.

# Unpack the output sequence
output, _ = pad_packed_sequence(packed_output, batch_first=True)

print(output.shape)
# output의 shape은 torch.Size([3, 6, 20])
# 실제로는 이렇게 unpack할 일이 없는 것 같다. pack을 바로바로 사용할 수 있는 점이 장점이다.

 

2. collate_fn 이란?

torch.utils.data.DataLoader에는 collate_fn이라는 옵션이 있다.

collate는 '통합하다'라는 뜻의 단어로, DataLoader에서는 list of samples를 batch 형태로 stacking하기 위한 argument이다.

DL 모델에서는 Parallizing 을 위해 Batch를 만들어야하므로, 필수적인 옵션이다. 내부적으로 collate 함수가 동작하나, customizing이 가능하다.

 

예제 코드는 아래와 같다.

import torch

def collate_fn(batch):
    # Unzip the batch into separate lists of inputs and targets
    inputs, targets = zip(*batch)

    # Stack the inputs and targets into tensors
    inputs = torch.stack(inputs)
    targets = torch.stack(targets)

    # Return the batched sample as a tuple of tensors
    return inputs, targets

input과 target을 따로 stacking하여 리턴하는 것을 확인할 수 있다. 리턴되는 값들은 batch사이즈의 sample이다.

collate_fn 은 custom 함수로 만들 수 있다. 예를 들면, input sequence의 길이가 달라 단순한 stacking이 어려운 경우 pacd_sequence와 결합되어 사용되기도 한다. 

 

class MyCollate:
    def __init__(self, pad_idx, batch_first):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    #__call__: a default method
    ##   First the obj is created using MyCollate(pad_idx) in data loader
    ##   Then if obj(batch) is called -> __call__ runs by default
    
    def __call__(self, batch):
    	# 2. batch 에서 서로 다른 길이의 source를 위해 padding을 진행한다.
        # get all source indexed sentences of the batch
        source = [item[0] for item in batch]
        # pad them using pad_sequence method from pytorch. 
        source = pad_sequence(source, batch_first=self.batch_first, padding_value = self.pad_idx) 
        
        # get all target indexed sentences of the batch
        target = [item[1] for item in batch] 
        # pad them using pad_sequence method from pytorch. 
        target = pad_sequence(target, batch_first=self.batch_first, padding_value = self.pad_idx)
        return source, target

# 0. dataset의 sequence가 서로 다른 길이를 가지고 있다고 가정한다. 여기서 dataset은 이미 batch형태로 만들어져 있다.
# 1. DataLoader에서 custom collate_fn 함수를 호출한다.
loader = DataLoader(dataset, batch_size = batch_size, shuffle=shuffle,
                        collate_fn = MyCollate(pad_idx=pad_idx, batch_first=batch_first)) #MyCollate class runs __call__ method by default

음, 여기 collate에서 stack은 왜 없는지는 확인해봐야한다.

DataLoader자체가 batch를 만들기 때문인가?

 

 

 

참고한 사이트

1. 이미지 출처: https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983

2. https://pytorch.org/docs/stable/data.html

3. https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html