from typing import Any
from .. import Layer
from ..util import Linear
from tensor_array.core import Tensor

def scaled_dot_product_attention(q, k, v, mask=None):
    # Attention scores: q @ k^T over the last two dimensions.
    attn_scores = q @ k.transpose(len(k.shape()) - 2, len(k.shape()) - 1)
    # NOTE: SoftMax is not imported in this file; it is assumed to be provided
    # elsewhere in the package. The `mask` argument is accepted but not yet applied.
    attn_probs = SoftMax(attn_scores, len(attn_scores.shape()) - 1)
    return attn_probs @ v

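# For reference, textbook scaled dot-product attention computes
#     softmax(Q @ K^T / sqrt(d_k)) @ V,
# adding a large negative value to masked-out score positions before the
# softmax. The helper above currently omits the 1 / sqrt(d_k) scaling and
# does not apply `mask`.
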
class MultiheadAttention(Layer):
    def __init__(self, d_model, n_head) -> None:
        super().__init__()
        # One linear projection each for queries, keys, values, and the output.
        self.linear_q = Linear(d_model)
        self.linear_k = Linear(d_model)
        self.linear_v = Linear(d_model)
        self.linear_o = Linear(d_model)
        self.n_head = n_head

    def calculate(self, input_q, input_k, input_v, mask=None) -> Any:
        # Project the inputs into query, key, and value spaces.
        temp_q = self.linear_q(input_q)
        temp_k = self.linear_k(input_k)
        temp_v = self.linear_v(input_v)

        # Per-head feature size; d_model is assumed to be divisible by n_head.
        d_head = temp_q.shape()[-1] // self.n_head

        # Split heads: (batch, seq, d_model) -> (batch, n_head, seq, d_head).
        temp_q = temp_q.reshape((temp_q.shape()[0], temp_q.shape()[1], self.n_head, d_head)).transpose(1, 2)
        temp_k = temp_k.reshape((temp_k.shape()[0], temp_k.shape()[1], self.n_head, d_head)).transpose(1, 2)
        temp_v = temp_v.reshape((temp_v.shape()[0], temp_v.shape()[1], self.n_head, d_head)).transpose(1, 2)

        attention_output = scaled_dot_product_attention(temp_q, temp_k, temp_v, mask)

        # Merge heads back: (batch, n_head, seq, d_head) -> (batch, seq, n_head * d_head).
        attention_output = attention_output.transpose(1, 2)
        out_shape = attention_output.shape()
        attention_output = attention_output.reshape((out_shape[0], out_shape[1], out_shape[2] * out_shape[3]))
        return self.linear_o(attention_output)
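
# A minimal NumPy walkthrough of the head split / attention / merge pattern used
# in `calculate` above. This is an illustrative sketch only: it does not use
# tensor_array, and the batch/sequence/model sizes below are arbitrary
# assumptions chosen for the demo.
if __name__ == "__main__":
    import numpy as np

    batch, seq_len, d_model, n_head = 2, 5, 16, 4
    d_head = d_model // n_head

    x = np.random.rand(batch, seq_len, d_model)

    # Split heads: (batch, seq, d_model) -> (batch, n_head, seq, d_head).
    split = x.reshape(batch, seq_len, n_head, d_head).transpose(0, 2, 1, 3)

    # Scaled dot-product self-attention per head: scores are (batch, n_head, seq, seq).
    scores = split @ split.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ split  # (batch, n_head, seq, d_head)

    # Merge heads back: (batch, n_head, seq, d_head) -> (batch, seq, d_model).
    merged = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
    assert merged.shape == (batch, seq_len, d_model)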