diff --git a/model_training_banana.py b/model_training_banana.py index 01cf360..0e55206 100644 --- a/model_training_banana.py +++ b/model_training_banana.py @@ -26,7 +26,7 @@ def get(self, idx): print(f"Number of node features: {dataset.num_node_features}") -seed_gnn = 0 +seed_gnn = int(sys.argv[1]) torch.manual_seed(seed_gnn) dataset_shuffle = dataset.shuffle()