https://github.com/pbpo/dumcoder/blob/main/tenarytransformers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TernaryEmbedding(nn.Module):
    """Embedding layer whose weights are ternarized to {-1, 0, +1} on the forward pass."""
    def __init__(self, num_embeddings, embedding_dim):
        super(TernaryEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))

    def forward(self, input):
        weight_ternary = self.ternary_weight()
        return F.embedding(input, weight_ternary)

    def ternary_weight(self):
        # Threshold at the mean absolute weight: values above it keep their sign,
        # the rest are zeroed out.
        abs_mean = self.weight.abs().mean()
        mask = self.weight.abs() > abs_mean
        weight_ternary = torch.where(mask, self.weight.sign(), torch.zeros_like(self.weight))
        return weight_ternary
class TernaryLinear(nn.Module):
    """Linear layer (no bias) whose weights are ternarized to {-1, 0, +1} on the forward pass."""
    def __init__(self, in_features, out_features):
        super(TernaryLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)

    def ternary_weight(self):
        # Same thresholding scheme as TernaryEmbedding.
        abs_mean = self.weight.abs().mean()
        mask = self.weight.abs() > abs_mean
        weight_ternary = torch.where(mask, self.weight.sign(), torch.zeros_like(self.weight))
        return weight_ternary

    def forward(self, input):
        weight_ternary = self.ternary_weight()
        return F.linear(input, weight_ternary)
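# Not in the original code: the ternarization above relies on sign(), whose gradient
# is zero almost everywhere, so the latent full-precision weights receive essentially
# no gradient during training. A common workaround is a straight-through estimator
# (STE): quantize on the forward pass, but let gradients pass through unchanged.
# Minimal standalone sketch (illustrative only; the classes above do not use it):
def ternarize_ste(weight):
    # Threshold at the mean absolute value, as in ternary_weight() above.
    abs_mean = weight.abs().mean()
    mask = weight.abs() > abs_mean
    w_q = torch.where(mask, weight.sign(), torch.zeros_like(weight))
    # Forward pass sees w_q; backward pass treats the quantization as identity.
    return weight + (w_q - weight).detach()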
class TernaryMultiheadAttention(nn.Module):
    """Multi-head attention built from ternary linear projections."""
    def __init__(self, d_model, num_heads):
        super(TernaryMultiheadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query_proj = TernaryLinear(d_model, d_model)
        self.key_proj = TernaryLinear(d_model, d_model)
        self.value_proj = TernaryLinear(d_model, d_model)
        self.out_proj = TernaryLinear(d_model, d_model)

    def forward(self, query, key, value, attn_mask=None):
        batch_size = query.size(0)
        # Project and split into heads: (batch, heads, seq_len, head_dim)
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            # Positions where the mask is 0 are blocked from attention.
            attn_scores = attn_scores.masked_fill(attn_mask == 0, -1e9)
        attn_probs = F.softmax(attn_scores, dim=-1)
        # Merge heads back: (batch, seq_len, d_model)
        attn_output = torch.matmul(attn_probs, value).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        attn_output = self.out_proj(attn_output)
        return attn_output
class TernaryFeedForward(nn.Module):
    """Position-wise feed-forward network with ternary linear layers."""
    def __init__(self, d_model, d_ff):
        super(TernaryFeedForward, self).__init__()
        self.linear1 = TernaryLinear(d_model, d_ff)
        self.linear2 = TernaryLinear(d_ff, d_model)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
class TernaryTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TernaryTransformerEncoderLayer, self).__init__()
        self.self_attn = TernaryMultiheadAttention(d_model, num_heads)
        self.feed_forward = TernaryFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sublayer: residual connection, then post-norm
        attn_output = self.self_attn(x, x, x, attn_mask=mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Feed-forward sublayer: residual connection, then post-norm
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x
class TernaryTransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TernaryTransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([
            TernaryTransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x
class TernaryTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TernaryTransformerDecoderLayer, self).__init__()
        self.self_attn = TernaryMultiheadAttention(d_model, num_heads)
        self.enc_dec_attn = TernaryMultiheadAttention(d_model, num_heads)
        self.feed_forward = TernaryFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_attn_mask=None, enc_dec_attn_mask=None):
        # Masked self-attention over the decoder input
        self_attn_output = self.self_attn(x, x, x, attn_mask=self_attn_mask)
        x = x + self.dropout1(self_attn_output)
        x = self.norm1(x)
        # Cross-attention over the encoder output
        enc_dec_attn_output = self.enc_dec_attn(x, enc_output, enc_output, attn_mask=enc_dec_attn_mask)
        x = x + self.dropout2(enc_dec_attn_output)
        x = self.norm2(x)
        # Feed-forward sublayer
        ff_output = self.feed_forward(x)
        x = x + self.dropout3(ff_output)
        x = self.norm3(x)
        return x
class TernaryTransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super(TernaryTransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([
            TernaryTransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, enc_output, self_attn_mask=None, enc_dec_attn_mask=None):
        for layer in self.layers:
            x = layer(x, enc_output, self_attn_mask, enc_dec_attn_mask)
        return x
class TernaryTransformer(nn.Module):
    def __init__(self, num_enc_layers, num_dec_layers, d_model, num_heads, d_ff,
                 input_vocab_size, output_vocab_size, max_seq_len, dropout=0.1):
        super(TernaryTransformer, self).__init__()
        self.encoder = TernaryTransformerEncoder(num_enc_layers, d_model, num_heads, d_ff, dropout)
        self.decoder = TernaryTransformerDecoder(num_dec_layers, d_model, num_heads, d_ff, dropout)
        self.enc_embedding = TernaryEmbedding(input_vocab_size, d_model)
        self.dec_embedding = TernaryEmbedding(output_vocab_size, d_model)
        self.position_enc = PositionalEncoding(d_model, max_seq_len, dropout)
        self.fc = TernaryLinear(d_model, output_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, enc_dec_attn_mask=None):
        # Encode the source sequence
        src_embedded = self.enc_embedding(src)
        src_embedded = self.position_enc(src_embedded)
        enc_output = self.encoder(src_embedded, src_mask)
        # Decode the target sequence against the encoder output
        tgt_embedded = self.dec_embedding(tgt)
        tgt_embedded = self.position_enc(tgt_embedded)
        dec_output = self.decoder(tgt_embedded, enc_output, tgt_mask, enc_dec_attn_mask)
        # Project to output vocabulary logits
        output = self.fc(dec_output)
        return output
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding with dropout."""
    def __init__(self, d_model, max_seq_len, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
I'm a beginner, so the mask argument is missing in places, including in the decoder.
Also, dropout and LayerNorm use the standard built-in implementations.
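Regarding the missing masks, here is a minimal usage sketch showing one way they could be built and passed in. It is not part of the original file; the pad token id of 0, the boolean mask convention (1 = attend, 0 = block), and all hyperparameters are illustrative assumptions.

# Illustrative usage sketch (not from the original repo); assumes pad id 0.
model = TernaryTransformer(num_enc_layers=2, num_dec_layers=2, d_model=64, num_heads=4,
                           d_ff=256, input_vocab_size=1000, output_vocab_size=1000,
                           max_seq_len=128, dropout=0.1)
src = torch.randint(1, 1000, (2, 16))   # (batch, src_len)
tgt = torch.randint(1, 1000, (2, 12))   # (batch, tgt_len)
# Source padding mask, broadcastable to (batch, heads, q_len, src_len)
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
# Causal mask for decoder self-attention: (1, 1, tgt_len, tgt_len)
tgt_mask = torch.tril(torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool)).unsqueeze(0).unsqueeze(0)
# Cross-attention reuses the source padding mask
logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, enc_dec_attn_mask=src_mask)
print(logits.shape)  # torch.Size([2, 12, 1000])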