300x250

 

 

목차

     

     

     

    Multi-modal, 특히 vision-language model에 관심을 갖게 되면서 관련 논문을 찾아보던 중, task-agnostic, modality-agnostic한 다양한 input을 다룰 수 있는 모델인 Perceiver와, 이에 이어 output까지 원하는 형태로 만들 수 있는 Perceiver IO를 접하게 되었다.

    Multi-modal 분야 뿐만 아니라 backbone으로(혹은 아이디어를 활용하여 구조를 변형하여) 훨씬 다양하게 활용될 수 있을 것 같아 두 아키텍쳐를 다룬 논문을 읽고 그 내용을 합쳐서 정리해보려 한다.

     

    먼저, Transformer의 핵심인 self-attention과 cross-attention이 여기서도 핵심 개념이기 때문에, 이를 다룬 글을 읽어보는 것을 추천한다.

    https://jjuke-brain.tistory.com/entry/%EB%94%A5%EB%9F%AC%EB%8B%9D-%EA%B8%B0%EC%B4%88-Attention-Transformer

     

    Transformers in Vision - (1) Attention & Transformer

    Transformer가 computer vision에서 어떻게 쓰였는지, 관련 모델이 어떻게 발전하고 있는지 여러 포스팅에 걸쳐서 알아보고자 한다. 이번 포스팅에서는 가장 중요한 기초 내용인 attention과 transformer에 대

    jjuke-brain.tistory.com

     

    그리고, 더 자세한 내용은 2021년 ICML에 발표된 Perceiver 논문 "Perceiver: General Perception with Iterative Attention"과 2022년 ICLR에 발표된 Perceiver IO 논문 "Perceiver IO: A General Architecture for Structured Inputs & Outputs"를 참조하자.

     

    이전 글에서는 general input을 다루는 Perceiver를 알아보았고, 이번에는 output까지 원하는 structure로 만들어낼 수 있는 Perceiver IO를 다뤄보려 한다.

     

     

     

    Method

     

    Fig 1. Perceiver IO pipeline

     

    이전 글에서도 언급했듯, Perceiver IO는 read-process-write (encode-process-decode) architecture이다.

    • Read ( \( \mathbf{x} \in \mathbb{R}^{M \times C} \rightarrow \mathbf{z} \in \mathbb{R}^{N \times D} \) ) : Input을 latent space로 인코딩하는 과정 (Perceiver의 cross-attention module과 같다.)
    • Process ( \( \mathbf{z} \rightarrow \mathbf{z}' \) ) : Latent representation을 정제(refine)하는 과정 (Perceiver의 Latent Transformer와 같다.)
    • Write ( \( \mathbf{z} \in \mathbb{R}^{N \times D} \rightarrow \mathbf{y} \in \mathbb{R}^{O \times E} \) ) : Latent space를 디코딩하는 과정
      • Read에서처럼 계산 과정과 output size를 분리하여 매우 큰 size도 출력이 가능하다!

     

    Read와 process 과정은 각각 Perceiver의 cross-attention module 및 self-attention module과 같으므로, write 과정을 살펴보자.

     

     

    Decoding

     

    Fig 2. Decoding in Perceiver IO

     

    Decoding 과정은 \(\mathbf{z} \in \mathbb{R}^{N \times D} \rightarrow \mathbf{y} \in \mathbb{R}^{O \times E}\)로 나타낼 수 있다. 이때 Perceiver때와 비슷하게 cross-attention을 적용한다.

     

    Fig 3. Encoding in Perceiver IO (Recap)

     

    Encoding 때(Fig 3)와 달리, latent가 key \(\mathbf{k} \in \mathbb{R}^{N \times d_\text{qk}}\) 및 value \( \mathbf{v} \in \mathbb{R}^{N \times d_\text{v}} \)가 되고, 원하는 structure(shape)의 output query array를 설정하여 query \(\mathbf{q} \in \mathbb{R}^{O \times d_\text{qk}}\)로 projection한다.

    최종 output array \(\mathbf{y} \in \mathbb{R}^{O \times E}\)의 shape dimension은 \(O\)로, query와 같다. \(O, E\)는 원하는 출력 데이터의 구조(shape)에 의해 정해지는 index dimension과 channel dimension을 나타낸다.

     

    Attention 계산의 복잡도를 생각해보면, Perceiver의 cross-attention(encoding) 때와 마찬가지로 latent를 활용함으로써 \(\mathcal{O}(ON)\)이 되며, 원하는 output size에 linear하게 증가한다. 따라서 복잡도가 \(O\)에 2차 비례하던 기존 Transformer에 비해 output size를 훨씬 크게 설정할 수 있다.

     

     

    Composition of the Output Query Array

     

    Fig 4. Output query arrays corresponding to specific tasks

     

    Fig 4는 원하는 task(dataset)에 따라 output query array를 어떻게 구성하는지 보여준다. Index dimension이 \(O\)인 query로 디코더를 querying하는데, 이 query array는 output space의 구조(shape, structure)를 담아야 한다. 즉, 각 output point의 정보(spatial position, modality 등)를 포함해야 한다.

    이를 위해서는 해당 정보를 serialize한 후, concat하거나 더해준다. Fig 4의 예시를 통해 자세히 알아보자.

    • 간단한 output (classification에서의 category(label) 등)
      • 모든 data example에 대해 같은 query를 사용할 수 있다. (즉, position encoding을 할 필요가 없다.) 따라서 Fig 4에서 Classification을 보면 '@... positions'라는 표현이 없는 것을 확인할 수 있다.
    • Spatial 혹은 sequential한 구조를 갖는 output (Text, Image 등)
      • Position encoding을 통해 위치(position) 정보를 나타내준다.
    • Multi-task 혹은 multimodal output (Fig 4에서 Multi-task classification 및 Multimodal autoencoding)
      • Task나 modality에 맞게 single query를 학습한 후에 사용한다.
      • 각 modality나 task에 해당하는 embedding(is_video, is_audio, is_label)을 concat해준다.
      • Position encoding이 위치를 구별하듯이, network가 task나 modality를 구별하도록 학습하는 개념이다.
    • Other tasks
      • Output이 input 내용을 반영하도록 query를 구성한다.
      • Otical flow같은 경우, 2D spatial 정보(x, y) 뿐만 아니라 input feature를 포함시켰더니 성능이 좋았다고 한다.
      • Starcraft 2의 경우, unit information을 포함시켰더니 성능이 좋았다고 한다.
      • 이와 같이, 간단한 query에 추가적인 정보를 학습하도록 해주면 더 좋은 성능을 보이는 경우가 있다.

     

    Decoder 코드(깃허브에 들어가보면 PerceiverClassificationDecoder, PerceiverOpticalFlowDecoder, PerceiverMultimodalDecoder 등 다양한 decoder가 있다. 그중 가장 기본적인 디코더이다.)는 아래와 같다.

     

    class PerceiverBasicDecoder(PerceiverAbstractDecoder):
        """
        Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
        cross-attention operation, in which the latents produce keys and values.
    
        The shape of the output of this class depends on how one defines the output queries (also called decoder queries).
    
        Args:
            config ([*PerceiverConfig*]):
                Model configuration.
            output_num_channels (`int`, *optional*):
                The number of channels in the output. Will only be used in case *final_project* is set to `True`.
            position_encoding_type (`str`, *optional*, defaults to "trainable"):
                The type of position encoding to use. Can be either "trainable", "fourier", or "none".
            output_index_dims (`int`, *optional*):
                The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
            num_channels (`int`, *optional*, defaults to 128):
                The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
            qk_channels (`int`, *optional*):
                The number of channels of the queries and keys in the cross-attention layer.
            v_channels (`int`, *optional*):
                The number of channels of the values in the cross-attention layer.
            num_heads (`int`, *optional*, defaults to 1):
                The number of attention heads in the cross-attention layer.
            widening_factor (`int`, *optional*, defaults to 1):
                The widening factor of the cross-attention layer.
            use_query_residual (`bool`, *optional*, defaults to `False`):
                Whether to use a residual connection between the query and the output of the cross-attention layer.
            concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
                Whether to concatenate the preprocessed input to the query.
            final_project (`bool`, *optional*, defaults to `True`):
                Whether to project the output of the cross-attention layer to a target dimension.
            position_encoding_only (`bool`, *optional*, defaults to `False`):
                Whether to only use this class to define output queries.
        """
    
        def __init__(
            self,
            config: PerceiverConfig,
            output_num_channels: int,
            position_encoding_type: Optional[str] = "trainable",
            # The following 2 arguments are ignored if position_encoding_type == 'none':
            output_index_dims: Optional[int] = None,
            num_channels: Optional[int] = 128,
            subsampled_index_dims: Optional[int] = None,
            qk_channels: Optional[int] = None,
            v_channels: Optional[int] = None,
            num_heads: Optional[int] = 1,
            widening_factor: Optional[int] = 1,
            use_query_residual: Optional[bool] = False,
            concat_preprocessed_input: Optional[bool] = False,
            final_project: Optional[bool] = True,
            position_encoding_only: Optional[bool] = False,
            **position_encoding_kwargs,
        ) -> None:
            super().__init__()
    
            self.output_num_channels = output_num_channels
            # If `none`, the decoder will not construct any position encodings.
            # You should construct your own when querying the decoder.
            self.output_position_encodings = None
            self.position_encoding_type = position_encoding_type
            self.position_encoding_kwargs = position_encoding_kwargs
            if position_encoding_type != "none":
                self.output_position_encodings, self.positions_projection = build_position_encoding(
                    position_encoding_type=position_encoding_type, **position_encoding_kwargs
                )
    
            self.output_index_dims = output_index_dims
            self.num_channels = num_channels
            if subsampled_index_dims is None:
                subsampled_index_dims = output_index_dims
            self.subsampled_index_dims = subsampled_index_dims
            self.concat_preprocessed_input = concat_preprocessed_input
            self.final_project = final_project
            self.position_encoding_only = position_encoding_only
    
            # for multimodal autoencoding, we don't need the decoder cross-attention and final layer
            # so then we will set position_encoding_only to True
            if not self.position_encoding_only:
                self.decoding_cross_attention = PerceiverLayer(
                    config,
                    is_cross_attention=True,
                    qk_channels=qk_channels,
                    v_channels=v_channels,
                    num_heads=num_heads,
                    q_dim=num_channels,
                    kv_dim=config.d_latents,
                    widening_factor=widening_factor,
                    use_query_residual=use_query_residual,
                )
                self.final_layer = nn.Linear(num_channels, output_num_channels) if final_project else nn.Identity()
    
        @property
        def num_query_channels(self) -> int:
            if self.position_encoding_type == "none":  # Queries come from elsewhere
                raise ValueError(
                    "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
                )
            if self.position_encoding_only:
                if "project_pos_dim" in self.position_encoding_kwargs:
                    return self.position_encoding_kwargs["project_pos_dim"]
                return self.output_position_encodings.output_size()
            if self.final_project:
                return self.output_num_channels
            return self.num_channels
    
        def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
            if self.position_encoding_type == "none":  # Queries come from elsewhere
                raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
            if subsampled_points is not None:
                # subsampled_points are the indices if the inputs would be flattened
                # however, the inputs aren't flattened, that's why we use unravel_index
                # to get the indices for the unflattened array
                # unravel_index returns a tuple (x_idx, y_idx, ...)
                # stack to get the [n, d] tensor of coordinates
                indices = [torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)]
                pos = torch.stack(indices, dim=1)
                batch_size = inputs.shape[0]
                # Map these coordinates to [-1, 1]
                pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
                pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
                # Construct the position encoding.
                if self.position_encoding_type == "trainable":
                    pos_emb = self.output_position_encodings(batch_size)
                elif self.position_encoding_type == "fourier":
                    pos_emb = self.output_position_encodings(
                        self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
                    )
    
                # Optionally project them to a target dimension.
                pos_emb = self.positions_projection(pos_emb)
                pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
            else:
                batch_size = inputs.shape[0]
                index_dims = inputs.shape[2:]
    
                # Construct the position encoding.
                if self.position_encoding_type == "trainable":
                    pos_emb = self.output_position_encodings(batch_size)
                elif self.position_encoding_type == "fourier":
                    pos_emb = self.output_position_encodings(
                        index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
                    )
    
                # Optionally project them to a target dimension.
                pos_emb = self.positions_projection(pos_emb)
    
            if self.concat_preprocessed_input:
                if inputs_without_pos is None:
                    raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
                pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
    
            return pos_emb
    
        def forward(
            self,
            query: torch.Tensor,
            z: torch.FloatTensor,
            query_mask: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = False,
        ) -> PerceiverDecoderOutput:
            # Cross-attention decoding.
            # key, value: B x N x K; query: B x M x K
            # Attention maps -> B x N x M
            # Output -> B x M x K
            cross_attentions = () if output_attentions else None
    
            layer_outputs = self.decoding_cross_attention(
                query,
                attention_mask=query_mask,
                head_mask=None,
                inputs=z,
                inputs_mask=None,
                output_attentions=output_attentions,
            )
            output = layer_outputs[0]
    
            if output_attentions:
                cross_attentions = cross_attentions + (layer_outputs[1],)
    
            logits = self.final_layer(output)
    
            return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)

     

     

     

     

    Perceiver IO 코드

     

    Perceiver IO가 동작하는 과정을 전체적으로 담은 PerceiverModel 클래스를 요약한 코드를 살펴보자.

     

    class PerceiverModel(PerceiverPreTrainedModel):
        def __init__(
            self,
            config,
            decoder=None,
            input_preprocessor: PreprocessorType = None,
            output_postprocessor: PostprocessorType = None,
        ):
            super().__init__(config)
            self.config = config
    
            self.input_preprocessor = input_preprocessor
            self.output_postprocessor = output_postprocessor
            self.embeddings = PerceiverEmbeddings(config)
            self.encoder = PerceiverEncoder(
                config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
            )
            self.decoder = decoder
    
            # Initialize weights and apply final processing
            self.post_init()
    
        def get_input_embeddings(self):
            return self.embeddings.latents
    
        def set_input_embeddings(self, value):
            self.embeddings.latents = value
    
        def _prune_heads(self, heads_to_prune):
            """
            Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
            class PreTrainedModel
            """
            for layer, heads in heads_to_prune.items():
                self.encoder.layer[layer].attention.prune_heads(heads)
    
        @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
        @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
        def forward(
            self,
            inputs: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            subsampled_output_points: Optional[Dict[str, torch.Tensor]] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
        ) -> Union[Tuple, PerceiverModelOutput]:
            ...
            batch_size, seq_length, _ = inputs.size()
            device = inputs.device
    
            ... # input preprocessing
    
            embedding_output = self.embeddings(batch_size=batch_size)
    
            encoder_outputs = self.encoder(
                ...
            )
            sequence_output = encoder_outputs[0]
    
            logits = None
            if self.decoder:
                if subsampled_output_points is not None:
                    output_modality_sizes = {
                        "audio": subsampled_output_points["audio"].shape[0],
                        "image": subsampled_output_points["image"].shape[0],
                        "label": 1,
                    }
                else:
                    output_modality_sizes = modality_sizes
                decoder_query = self.decoder.decoder_query(
                    inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
                )
                decoder_outputs = self.decoder(
                    ...
                )
                logits = decoder_outputs.logits
    
                # add cross-attentions of decoder
                if output_attentions and decoder_outputs.cross_attentions is not None:
                    if return_dict:
                        encoder_outputs.cross_attentions = (
                            encoder_outputs.cross_attentions + decoder_outputs.cross_attentions
                        )
                    else:
                        encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions
                 
                 ... # postprocessing
    
            return PerceiverModelOutput(
                logits=logits,
                last_hidden_state=sequence_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

     

    728x90
    • 네이버 블러그 공유하기
    • 네이버 밴드에 공유하기
    • 페이스북 공유하기
    • 카카오스토리 공유하기