GraphTrans-mean may be wrong #8

Open
LUOyk1999 opened this issue Oct 9, 2022 · 0 comments
LUOyk1999 commented Oct 9, 2022

Hello, thanks for the excellent work. However, I have found a possible problem.
In the paper, the authors mention that "In Table 5, we tested several common methods for sequence classification. The mean operation averages the output embeddings of the transformer to a single graph embedding; the last operation takes the last embedding in the output sequence as the graph embedding."

Table 5:

| Model | Valid | Test |
| --- | --- | --- |
| GraphTrans-mean | 0.1398 | 0.1509 |

However, looking at the GraphTrans code, I found that the implementation of the mean operation may be wrong.

`gnn_transformer.py`, lines 116-117:

```python
elif self.pooling == "mean":
    h_graph = transformer_out.sum(0) / src_padding_mask.sum(-1, keepdim=True)
```

Here `transformer_out.shape == (S, B, h_d)` and `src_padding_mask.shape == (B, S)`.
`transformer_out` still contains the embeddings of the padding nodes, and the code does not remove them (it does not call `unpad_batch`) before summing over the sequence dimension.
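For reference, a masked mean that excludes the padding positions could look something like the sketch below. This is a minimal sketch, not the exact change I made; `masked_mean` and `valid_mask` are my own names, and `valid_mask` is assumed to be True at real (non-padding) positions, so if `src_padding_mask` marks padding with True it would need to be inverted first.

```python
import torch

def masked_mean(transformer_out: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool transformer outputs over the sequence, ignoring padding.

    transformer_out: (S, B, h_d) output of the transformer encoder.
    valid_mask:      (B, S) boolean mask, True at real (non-padding) positions.
    """
    # Reshape the mask to (S, B, 1) so it broadcasts against (S, B, h_d).
    mask = valid_mask.transpose(0, 1).unsqueeze(-1).to(transformer_out.dtype)
    # Zero out the padding positions before summing over the sequence dimension.
    summed = (transformer_out * mask).sum(0)                 # (B, h_d)
    # Divide by the number of real nodes per graph (clamped to avoid 0/0).
    counts = valid_mask.sum(-1, keepdim=True).clamp(min=1)   # (B, 1)
    return summed / counts                                   # (B, h_d)
```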

I modified the mean operation so that padding nodes are excluded, then redid the experiment and found that the result improves over the one reported by the authors.
