unit8co/darts

[BUG] Implementation errors on `TransformerModel`.

Open

#672 opened on December 6, 2021

Labels: bug, good first issue

Description

I have looked into the [Transformer](https://github.com/unit8co/darts/blame/master/darts/models/forecasting/transformer_model.py) implementation and found some errors.

First,

In lines 167 and 170:

```python
src = self.encoder(src) * math.sqrt(self.input_size)
tgt = self.encoder(tgt) * math.sqrt(self.input_size)
```

I don't think we need to multiply the inputs (src or tgt) by math.sqrt(self.input_size), because torch.nn.MultiheadAttention already takes care of this normalization.
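To see exactly which normalization torch.nn.MultiheadAttention applies, one can recompute its output by hand. The sketch below is only an illustration (sizes and seed are arbitrary, and `batch_first=True` assumes PyTorch 1.9 or later): it rebuilds self-attention from the module's own projection weights and checks that the result matches when the query-key scores of each head are divided by `sqrt(head_dim)`. This scaling happens inside each head after the input projections; whether it makes the scaling of the encoder output redundant is the question raised above.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads

# bias=False drops the input/output projection biases to keep the recomputation short.
# Dropout defaults to 0.0, so the module is deterministic even in training mode.
mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)

x = torch.randn(3, 5, embed_dim)  # (batch, seq, embed)
expected, _ = mha(x, x, x, need_weights=False)


def split_heads(t: torch.Tensor) -> torch.Tensor:
    """(batch, seq, embed) -> (batch, heads, seq, head_dim)"""
    b, s, _ = t.shape
    return t.view(b, s, num_heads, head_dim).transpose(1, 2)


def manual_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Recompute mha(x, x, x) from the module's own weights."""
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    q, k, v = split_heads(x @ w_q.T), split_heads(x @ w_k.T), split_heads(x @ w_v.T)
    # The built-in normalization: scores are divided by sqrt(head_dim).
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    out = F.softmax(scores, dim=-1) @ v
    b, _, s, _ = out.shape
    out = out.transpose(1, 2).reshape(b, s, embed_dim)  # merge the heads back
    return mha.out_proj(out)


actual = manual_self_attention(x)
print(torch.allclose(actual, expected, atol=1e-5))
```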

Second,

In lines 173-174:

```python
x = self.transformer(src=src,
                     tgt=tgt)
```

No tgt_mask is passed here. To use teacher forcing during training, the user must pass a tgt_mask to the forward function (specifically, the mask produced by generate_square_subsequent_mask, shown below). Otherwise, the decoder output at time t can attend to future decoder inputs (e.g., t+1, t+2, ...), which are not available at inference time.
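To make the leakage concrete, here is a small sketch using a plain torch.nn.Transformer rather than the darts module (all sizes are arbitrary; `batch_first=True` assumes PyTorch 1.9 or later). It changes only the last decoder input and checks whether the decoder outputs for the earlier time steps change, first without a mask (as in the current forward pass) and then with `tgt_mask`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, nhead, src_len, tgt_len, batch = 8, 2, 6, 5, 2
model = nn.Transformer(
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=1,
    num_decoder_layers=1,
    dim_feedforward=16,
    dropout=0.0,  # no randomness, so outputs can be compared directly
    batch_first=True,
)
model.eval()

src = torch.randn(batch, src_len, d_model)
tgt = torch.randn(batch, tgt_len, d_model)

# Same decoder input, except the last time step ("the future") is replaced.
tgt_changed = tgt.clone()
tgt_changed[:, -1] = torch.randn(batch, d_model)

# Additive float mask: -inf above the diagonal, 0.0 on and below it.
tgt_mask = model.generate_square_subsequent_mask(tgt_len)

with torch.no_grad():
    # No mask: every position attends to the last step, so earlier outputs move.
    leak = model(src, tgt)[:, :-1]
    leak_changed = model(src, tgt_changed)[:, :-1]
    # With tgt_mask: outputs before the changed step cannot see it.
    causal = model(src, tgt, tgt_mask=tgt_mask)[:, :-1]
    causal_changed = model(src, tgt_changed, tgt_mask=tgt_mask)[:, :-1]

print("no mask, earlier outputs unchanged:", torch.allclose(leak, leak_changed, atol=1e-6))
print("with mask, earlier outputs unchanged:", torch.allclose(causal, causal_changed, atol=1e-6))
```

Without the mask the first check should report False; with the mask the earlier outputs stay the same, which is the behavior teacher forcing relies on.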

```python
import torch
from torch import Tensor


# Excerpt from torch.nn.Transformer, where this is defined as a @staticmethod.
def generate_square_subsequent_mask(sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
```
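Because this is a float mask, PyTorch adds it to the attention scores before the softmax, so the -inf entries end up with exactly zero weight. A small sketch, with random scores standing in for the query-key scores of a single head (the scores are made up; the mask construction is the one quoted above):

```python
import torch


def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Same construction as above: -inf strictly above the diagonal, 0.0 elsewhere.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


torch.manual_seed(0)
sz = 4
scores = torch.randn(sz, sz)  # hypothetical stand-in for one head's scores
mask = generate_square_subsequent_mask(sz)

# Adding the mask sends future positions to -inf, and exp(-inf) = 0 in the softmax,
# so row t only distributes its attention over positions 0..t.
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
```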

I'm not completely sure these are errors, but in my opinion the current implementation does not seem correct.

Thank you!
