tril을 통해 lower traingular matrix를 만들 수 있다.
attention_mask = torch.ones(2,2).tril(diagonal=0)
attention "mask"라는 이름처럼, attention mask는 1인 부분의 data만 남기고 0인 부분의 data를 지우는 데 사용한다.
혹은, masked_fill을 사용해 traingular matrix의 특정 부분을 원하는 값으로 바꿀 수 있다. 아래의 코드는 0인 부분을 -infinity로 바꾸는 코드이다.
attention_mask = attention_mask.masked_fill(attention_mask==0, -float('inf'))
attention_mask
attention mask에 softmax를 취해서 자신이 알고 있는 정보를 평균으로 구할 수 있다. 앞선 포스팅에서 들었던 character 예시에서 첫 번째 character는 자기 자신 이외의 정보를 알 수 없고, 두 번째 character는 자기 자신과 첫 번째 character에 대한 정보를 알 수 있었던 것을 떠올려 보자.
attention_mask = torch.ones(4,4).tril(diagonal=0)
attention_mask = attention_mask.masked_fill(attention_mask==0, -float('inf'))
torch.softmax(attention_mask, dim=-1)
자신이 정보를 알고 있는 data 수를 가지고 1을 평균내서 표시하였다.
앞선 포스팅에서 scaled dot product를 구하기 위해 사용했던 코드를 응용해서 attention mask가 어떤 방식으로 machine learning에 사용되는지 예를 들어 보자.
torch.softmax((Q@K.transpose(-2,-1)/math.sqrt(Q.size(-1)))+attention_mask, dim=-1)
위 결과값은 각 정보에 주어질 weight의 확률을 나타내고, 이는 다른 data와 얼마나 연관이 있는지 나타낸다. 이를 machine learning에서는 attention이라고 하고 이를 확률의 형태로 계산하는 것이다.
결과적으로 scaled dot product에 attention mask를 씌우면, 그 결과값은 data가 다른 data에 대해 알고 있는 정보에 대한 weighted average가 산출된다.
앞서 Q(query), K(key), V(value)를 Multi-Head Attention에 전달하였다. 이 3 값은 모두 embedding 된 값으로, 이들을 통해 scared dot product attention을 구하고, 이것들을 concatenate하여 다시 linear layer에 전달하는 것이 Multi-Head Attention이다.
즉 우리는 여기까지 scaled dot product가 무엇인지, attention은 무엇이고 attention mask는 어떤 역할을 하는지 알아보았다. 이 역할을 하는 code는 아래와 같다. 이해를 위해 직접 구현한 scaled dot product를 예시로 많이 들었지만 torch.nn.functional에서 scaled_dot_product_attention을 지원한다는 사실을 기억하자.
F.scaled_dot_product_attention(Q,K,V,attention_mask)
이 때, 만일 각 data 사이에 연관이 있다면(ex.앞 뒤로 서로 영향을 미치는 경우) 여기에 is_causal option을 true로 설정해 주어야 한다.
F.scaled_dot_product_attention(Q,K,V,attention_mask, is_causal=True)
'경제학코딩 2023' 카테고리의 다른 글
[Decoder Model] (1) | 2023.12.17 |
---|---|
[Decoder Data Step] (1) | 2023.12.17 |
[Scaled dot product] (1) | 2023.12.17 |
[Encoder Only Model with imdb] (0) | 2023.12.17 |
[Encoder Only Model with Random Data] (1) | 2023.12.17 |