import torch
from torch import nn
import torch.nn.functional as F
import math

class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, query, key, value, mask=None):
        query, key, value = self.q(query), self.k(key), self.v(value)
        # Scaled dot-product attention; dividing by sqrt(head_dim) keeps the
        # logits in a range where the softmax gradients stay healthy.
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(query.size(-1))
        if mask is not None:
            # Positions where mask == 0 get -inf and vanish after the softmax.
            scores = scores.masked_fill(mask == 0, -float("inf"))
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, value)
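
# A minimal sanity check of a single head (a sketch, not part of the original
# listing; the 768/64 sizes are an assumption matching bert-base-uncased with
# 12 heads, and the causal mask shows the mask == 0 convention in action).
def _demo_attention_head():
    head = AttentionHead(embed_dim=768, head_dim=64)
    x = torch.randn(1, 5, 768)  # (batch, seq_len, embed_dim)
    causal_mask = torch.tril(torch.ones(5, 5)).unsqueeze(0)  # (1, seq, seq)
    out = head(x, x, x, mask=causal_mask)
    assert out.shape == (1, 5, 64)  # each head outputs head_dim features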

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None, query_mask=None, key_mask=None):
        # If per-sequence 0/1 padding masks are given, combine them into a
        # (batch, seq_q, seq_k) attention mask via an outer product.
        if query_mask is not None and key_mask is not None:
            mask = torch.bmm(query_mask.unsqueeze(-1), key_mask.unsqueeze(1))
        # Run all heads and concatenate their outputs back to embed_dim.
        x = torch.cat([h(query, key, value, mask) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x
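
# Quick shape check for the multi-head wrapper (a sketch; SimpleNamespace
# stands in for the Hugging Face config object here, with sizes assumed from
# bert-base-uncased: hidden_size=768, num_attention_heads=12).
def _demo_multi_head_attention():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=768, num_attention_heads=12)
    mha = MultiHeadAttention(cfg)
    x = torch.randn(2, 5, 768)
    assert mha(x, x, x).shape == (2, 5, 768)  # attention preserves model width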

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        # Expand, apply the nonlinearity, then project back down. The GELU
        # belongs only between the two linear layers; applying it a second
        # time after linear2 (as the original code did) would clip the
        # negative values the residual stream needs.
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Each sub-layer gets its own LayerNorm; reusing a single module for
        # both (as the original code did) would tie their learned scale and
        # bias parameters.
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feedforward = FeedForward(config)

    def forward(self, x, mask=None):
        # Pre-layer-norm block: normalize, apply the sub-layer, then add the
        # skip connection.
        hidden_state = self.layer_norm_1(x)
        x = x + self.attention(hidden_state, hidden_state, hidden_state, mask=mask)
        x = x + self.feedforward(self.layer_norm_2(x))
        return x

class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        # Use the configured dropout rate; a bare nn.Dropout() defaults to
        # p=0.5, which is far too aggressive for embeddings.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        # Create position ids on the same device as the input so the module
        # also works on GPU.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x, mask=None):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x


if __name__ == "__main__":
    from transformers import AutoConfig
    from transformers import AutoTokenizer
    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    config = AutoConfig.from_pretrained(model_ckpt)
    text1 = "time flies like an arrow"
    text2 = "I love you"
    # Join the sentences with a space; text1 + text2 would fuse "arrow" and
    # "I" together with no token boundary between them.
    inputs = tokenizer(text1 + " " + text2, return_tensors="pt", add_special_tokens=False)
    encoder = TransformerEncoder(config)
    print(encoder(inputs.input_ids).size())
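
    # A sketch of how a padding mask would flow through the encoder (this uses
    # the tokenizer's attention_mask; with a single unpadded sequence it is all
    # ones, so the masked output matches the unmasked one).
    pad_mask = inputs.attention_mask                       # (batch, seq_len)
    mask = pad_mask.unsqueeze(-1) * pad_mask.unsqueeze(1)  # (batch, seq, seq)
    print(encoder(inputs.input_ids, mask=mask).size())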