Possibly Incorrect Calculation of Perplexity in Pytorch Implementation #131

shaan97 opened this issue Mar 18, 2021 · 0 comments
shaan97 commented Mar 18, 2021

This is my first time posting an issue, so apologies if I've written something incorrectly or am missing something obvious.

On lines 82-88 of the file transformer-xl/pytorch/eval.py, perplexity is computed by accumulating the total loss and the total segment length:

    mems = tuple()
    for idx, (data, target, seq_len) in enumerate(eval_iter):
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.mean()
        total_loss += seq_len * loss.item()
        total_len += seq_len
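
(If I'm reading the rest of eval.py correctly, these totals are later turned into the reported perplexity as roughly:

    import math

    ppl = math.exp(total_loss / total_len)  # perplexity = exp(average per-token loss)

so the ratio of the two accumulators is what actually matters.)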

Rather than adding loss.sum() to the total loss, the implementation multiplies the mean by seq_len. However, when computing the loss there should only be seq_len - 1 loss terms in the model's output (in language modeling you predict the next token from the previous tokens, so no loss is computed for the very first token).
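
To make the over-counting concrete, here is a toy segment with seq_len = 4 and made-up per-position losses (only seq_len - 1 = 3 of them, per the argument above):

    import torch

    seq_len = 4
    loss = torch.tensor([2.0, 1.0, 3.0])  # hypothetical losses, one per predicted token

    print(seq_len * loss.mean().item())  # 8.0 -- what eval.py currently adds to total_loss
    print(loss.sum().item())             # 6.0 -- the sum of the loss terms that actually exist

The accumulated total behaves as if there were four loss terms when only three exist.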

(Compare this against the TF implementation in the file transformer-xl/tf/train_gpu.py:

  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
  else:
    loss = tower_losses[0]

Here the issue is avoided because all the losses are appended to the list tower_losses, which is then summed and divided by its length.)

This is subtle because the resulting perplexity value will still look plausible, but the computation is effectively pretending to include one extra loss term per segment. I think this is the correct implementation:

    mems = tuple()
    for idx, (data, target, seq_len) in enumerate(eval_iter):
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.mean()
        total_loss += (seq_len - 1) * loss.item()
        total_len += seq_len - 1

Or:

    mems = tuple()
    for idx, (data, target, seq_len) in enumerate(eval_iter):
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        total_loss += loss.sum().item()
        total_len += seq_len - 1
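
(Both variants should accumulate identical totals whenever the loss tensor really does hold seq_len - 1 entries; a quick illustrative check with a made-up tensor:

    import torch

    seq_len = 4
    loss = torch.tensor([2.0, 1.0, 3.0])  # seq_len - 1 hypothetical loss terms
    assert abs((seq_len - 1) * loss.mean().item() - loss.sum().item()) < 1e-6

so the choice between them is just a matter of style.)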

Is this a bug? Am I missing something? Thanks!
