[D] Trying to use BERT for simple classification, but it doesn't work, please help
So, the idea is simple: feed BERT's contextual token vectors into a GRU layer, then classify from the GRU's final hidden state.
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        freeze(self.bert)
        self.rnn = torch.nn.GRU(768, 128, 1, batch_first=True, bidirectional=False)
        self.linear = torch.nn.Linear(128, 5)

    def forward(self, x, lengths):
        x, _ = self.bert(x)
        x = torch.nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True)
        self.rnn.flatten_parameters()
        _, x = self.rnn(x)
        x = self.linear(x.squeeze(0))
        return x
But the model's loss does not decrease.
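To check the GRU and linear part separately from BERT, here is a minimal sketch that runs the same head on random tensors of BERT's output shape. Several things in it are assumptions, not taken from the code above: the body of `freeze` (the real helper is not shown, so this version just disables gradients), the class name `GruHead`, the random `features` standing in for BERT's token outputs, and the `enforce_sorted=False` argument (available in PyTorch 1.1 and later), which lets `pack_padded_sequence` accept lengths that are not sorted in descending order.

```python
import torch


def freeze(module: torch.nn.Module) -> None:
    # Hypothetical stand-in for the post's `freeze` helper (its body is not shown):
    # disable gradient updates for every parameter of the module.
    for param in module.parameters():
        param.requires_grad = False


class GruHead(torch.nn.Module):
    """GRU + linear classifier over per-token features (hypothetical name)."""

    def __init__(self, hidden_size=768, rnn_size=128, num_classes=5):
        super().__init__()
        self.rnn = torch.nn.GRU(hidden_size, rnn_size, 1, batch_first=True)
        self.linear = torch.nn.Linear(rnn_size, num_classes)

    def forward(self, features, lengths):
        # Pack so the GRU stops at each sequence's true length instead of
        # reading padding positions.
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            features, lengths, batch_first=True, enforce_sorted=False)
        _, h_n = self.rnn(packed)            # h_n: (1, batch, rnn_size)
        return self.linear(h_n.squeeze(0))   # (batch, num_classes)


torch.manual_seed(0)
features = torch.randn(4, 10, 768)      # random stand-in for BERT token outputs
lengths = torch.tensor([10, 7, 3, 8])   # unsorted lengths, kept on the CPU
head = GruHead()
logits = head(features, lengths)
logits_shape = tuple(logits.shape)

# A frozen module should have no trainable parameters; the head should.
frozen = torch.nn.Linear(768, 768)
freeze(frozen)
num_trainable_frozen = sum(p.requires_grad for p in frozen.parameters())
num_trainable_head = sum(p.requires_grad for p in head.parameters())
```

If the shapes come out as expected and the head still has trainable parameters, the remaining places to look are probably the BERT call itself (what it returns, whether padding is masked) and the training loop, neither of which this sketch covers.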
submitted by /u/hadaev