[Decoder Model]
class SimpleModel(nn.Module): def __init__(self): super().__init__() self.tok_emb = nn.Embedding(vocab_size, n_emb) self.pos_emb = nn.Embedding(block_size, n_emb) self.blocks = nn.Sequential( Block(n_emb, n_head=4), Block(n_emb, n_head=4), Block(n_emb, n_head=4), ) self.lm_head = nn.Linear(n_emb, vocab_size) def forward(self, idx, targets=None): B, T = idx.shape tok_emb = self.tok_emb(idx) # (B,..
2023. 12. 17.