Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tensor parallel features to RotatE #384

Open
wants to merge 5 commits into
base: OpenKE-PyTorch
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ venv/
ENV/
env.bak/
venv.bak/

.DS_Store
benchmarks/
# Spyder project settings
.spyderproject
.spyproject
Expand Down
22 changes: 22 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
    // VS Code debug configuration for the distributed training scripts.
    // Runs the target script under torch.distributed.launch with one
    // worker process per GPU listed in CUDA_VISIBLE_DEVICES.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            // Renamed from "Python: Current File": this configuration does
            // not debug the current editor file — it always launches the
            // fixed script below via torch.distributed.launch.
            "name": "Python: torch.distributed.launch (2 GPUs)",
            "type": "python",
            "module": "torch.distributed.launch",
            "request": "launch",
            "console": "integratedTerminal",
            "justMyCode": true,
            "env": {
                "CUDA_VISIBLE_DEVICES": "0,1"
            },
            "args": [
                "--nproc_per_node", "2",
                "./train_transe_FB15K237.py"
            ]
        }
    ]
}
5 changes: 4 additions & 1 deletion openke/config/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def __init__(self,
use_gpu = True,
opt_method = "sgd",
save_steps = None,
checkpoint_dir = None):
checkpoint_dir = None,
mode="single"
):

self.work_threads = 8
self.train_times = train_times
Expand All @@ -39,6 +41,7 @@ def __init__(self,
self.use_gpu = use_gpu
self.save_steps = save_steps
self.checkpoint_dir = checkpoint_dir
self.mode = mode

def train_one_step(self, data):
self.optimizer.zero_grad()
Expand Down
Empty file modified openke/module/model/RotatE.py
100755 → 100644
Empty file.
105 changes: 105 additions & 0 deletions openke/module/model/RotatE_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
import torch.autograd as autograd
import torch.nn as nn
from .Model import Model
import torch.distributed as dist

class DistributedRotatE(Model):
    """Tensor-parallel variant of RotatE.

    Each process stores only a ``1/world_size`` slice of every entity and
    relation embedding; the per-shard partial distances computed in
    ``_calc`` are summed across processes with ``dist.all_reduce`` so every
    rank ends up holding the full-dimension score.

    NOTE(review): assumes ``torch.distributed`` has been initialised before
    the first forward pass, and that ``dim`` (and ``2*dim``) are divisible
    by ``world_size`` — confirm at the call sites.
    """

    def __init__(self, ent_tot, rel_tot, dim = 100, margin = 6.0, epsilon = 2.0, world_size=2):
        # ent_tot / rel_tot: number of entities / relations in the KG.
        # dim: FULL (un-sharded) relation embedding size; entity embeddings
        #      use 2*dim because they are complex-valued (real + imaginary).
        # margin, epsilon: jointly set the uniform init range and the score
        #      offset applied in forward().
        # world_size: number of processes the embedding dimension is split over.
        super(DistributedRotatE, self).__init__(ent_tot, rel_tot)

        self.margin = margin
        self.epsilon = epsilon

        # Entities are complex: first half of the last dim is the real part,
        # second half the imaginary part (split by torch.chunk in _calc).
        self.dim_e = dim * 2
        self.dim_r = dim

        # Each rank holds only its slice of the embedding dimension.
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim_e // world_size)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim_r // world_size)

        # Init range is scaled to the LOCAL (sharded) dimension, not the
        # full one — so sharded and unsharded models are initialised from
        # different distributions. NOTE(review): verify this is intended.
        self.ent_embedding_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / (self.dim_e // world_size)]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.ent_embeddings.weight.data,
            a=-self.ent_embedding_range.item(),
            b=self.ent_embedding_range.item()
        )

        self.rel_embedding_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / (self.dim_r // world_size)]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.rel_embeddings.weight.data,
            a=-self.rel_embedding_range.item(),
            b=self.rel_embedding_range.item()
        )

        # Rebinds self.margin from the plain float above to a frozen tensor
        # so it can take part in tensor arithmetic in forward(). The float
        # value was already consumed by the embedding-range computations.
        self.margin = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False

    def _calc(self, h, t, r, mode):
        """Compute the (full-dimension) RotatE distance for a batch.

        Returns a flattened 1-D tensor of distances; lower means a better
        (more plausible) triple. Performs a collective all_reduce, so every
        participating rank must call this in lockstep.
        """
        # pi_const is presumably pi, provided by the Model base class
        # (not visible here) — TODO confirm.
        pi = self.pi_const

        # Split the complex entity embeddings into real/imaginary halves.
        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        # Interpret each relation coordinate as a rotation phase; the
        # division maps the raw embedding from its init range onto radians.
        phase_relation = r / (self.rel_embedding_range / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        # Reshape everything to (num_relations, negatives, local_dim) so the
        # element-wise complex product below broadcasts per relation.
        # NOTE(review): exact batch layout depends on the dataloader — confirm.
        re_head = re_head.view(-1, re_relation.shape[0], re_head.shape[-1]).permute(1, 0, 2)
        re_tail = re_tail.view(-1, re_relation.shape[0], re_tail.shape[-1]).permute(1, 0, 2)
        im_head = im_head.view(-1, re_relation.shape[0], im_head.shape[-1]).permute(1, 0, 2)
        im_tail = im_tail.view(-1, re_relation.shape[0], im_tail.shape[-1]).permute(1, 0, 2)
        im_relation = im_relation.view(-1, re_relation.shape[0], im_relation.shape[-1]).permute(1, 0, 2)
        re_relation = re_relation.view(-1, re_relation.shape[0], re_relation.shape[-1]).permute(1, 0, 2)

        if mode == "head_batch":
            # Rotate the tail by the conjugate relation and compare to the
            # head: score = (r^-1 ∘ t) - h, expanded into real/imag parts.
            re_score = re_relation * re_tail + im_relation * im_tail
            im_score = re_relation * im_tail - im_relation * re_tail
            re_score = re_score - re_head
            im_score = im_score - im_head
        else:
            # Rotate the head by the relation and compare to the tail:
            # score = (h ∘ r) - t, expanded into real/imag parts.
            re_score = re_head * re_relation - im_head * im_relation
            im_score = re_head * im_relation + im_head * re_relation
            re_score = re_score - re_tail
            im_score = im_score - im_tail

        # Per-coordinate complex modulus, then sum over the LOCAL dimension
        # slice held by this rank.
        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0).sum(dim = -1)
        # Sum the per-shard partial sums across all ranks: because the RotatE
        # distance is a plain sum over embedding coordinates, the reduced
        # value equals the full-dimension distance on every rank.
        dist.all_reduce(score, dist.ReduceOp.SUM,)
        return score.permute(1, 0).flatten()

    def forward(self, data):
        """Return margin - distance for each triple in the batch.

        Higher output means a more plausible triple. ``data`` is the OpenKE
        batch dict with keys 'batch_h', 'batch_t', 'batch_r', 'mode'.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        score = self.margin - self._calc(h ,t, r, mode)
        return score

    def predict(self, data):
        """Return negated scores as a NumPy array (lower = more plausible),
        matching the convention the OpenKE Tester expects."""
        score = -self.forward(data)
        return score.cpu().data.numpy()

    def regularization(self, data):
        """Mean-square (L2) regulariser over the embeddings used in the batch.

        NOTE(review): computed on the local shard only — no all_reduce here.
        Because shards are equal-sized the local mean is an unbiased piece of
        the global mean, but gradients only touch local parameters; confirm
        this matches the intended regularisation strength.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        regul = (torch.mean(h ** 2) +
                 torch.mean(t ** 2) +
                 torch.mean(r ** 2)) / 3
        return regul
20 changes: 16 additions & 4 deletions openke/module/model/TransE.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import torch.nn as nn
import torch.nn.functional as F
from .Model import Model
import torch.distributed as dist


class TransE(Model):

def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None):
def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None, world_size=2):
super(TransE, self).__init__(ent_tot, rel_tot)

self.dim = dim
Expand All @@ -14,15 +16,16 @@ def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, ma
self.norm_flag = norm_flag
self.p_norm = p_norm

self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)

self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim // world_size)
self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim // world_size)
if margin == None or epsilon == None:
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
self.embedding_range = nn.Parameter(
torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False

torch.Tensor([(self.margin + self.epsilon) / self.dim//world_size]), requires_grad=False
)
nn.init.uniform_(
tensor = self.ent_embeddings.weight.data,
Expand All @@ -45,9 +48,16 @@ def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, ma

def _calc(self, h, t, r, mode):
if self.norm_flag:
<<<<<<< HEAD
h = F.normalize(h, 2, -1)
r = F.normalize(r, 2, -1)
t = F.normalize(t, 2, -1)

=======
h = F.normalize(h)
r = F.normalize(r)
t = F.normalize(t)
>>>>>>> 642bba4563d441d11cefbbfb2dd94ef4deac2206
if mode != 'normal':
h = h.view(-1, r.shape[0], h.shape[-1])
t = t.view(-1, r.shape[0], t.shape[-1])
Expand All @@ -57,6 +67,8 @@ def _calc(self, h, t, r, mode):
else:
score = (h + r) - t
score = torch.norm(score, self.p_norm, -1).flatten()

dist.all_reduce(score, dist.ReduceOp.SUM,)
return score

def forward(self, data):
Expand Down
4 changes: 3 additions & 1 deletion openke/module/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .Analogy import Analogy
from .SimplE import SimplE
from .RotatE import RotatE
from .RotatE_dist import DistributedRotatE

__all__ = [
'Model',
Expand All @@ -25,5 +26,6 @@
'RESCAL',
'Analogy',
'SimplE',
'RotatE'
'RotatE',
'DistributedRotatE'
]
1 change: 1 addition & 0 deletions stats/running_cost_time.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
running time 22183.74623656273
1 change: 1 addition & 0 deletions stats/single_running_cost_time.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions stats/total_cost_time_rotate.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
total elapsed time 30921.93621945381total elapsed time 30921.948616981506
48 changes: 48 additions & 0 deletions train_rotate_FB15K237.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Single-GPU RotatE training and link-prediction evaluation on FB15K237."""
import openke
from openke.config import Trainer, Tester
from openke.module.model import RotatE
from openke.module.loss import SigmoidLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# Training batches: cross sampling with 64 corrupted entities per positive
# triple, no corrupted relations.
train_loader = TrainDataLoader(
    in_path = "./benchmarks/FB15K237/",
    batch_size = 2000,
    threads = 8,
    sampling_mode = "cross",
    bern_flag = 0,
    filter_flag = 1,
    neg_ent = 64,
    neg_rel = 0
)

# Held-out triples for link prediction.
test_loader = TestDataLoader("./benchmarks/FB15K237/", "link")

# RotatE scoring model over the loaded entity/relation vocabularies.
kge_model = RotatE(
    ent_tot = train_loader.get_ent_tot(),
    rel_tot = train_loader.get_rel_tot(),
    dim = 1024,
    margin = 6.0,
    epsilon = 2.0,
)

# Self-adversarial negative sampling wrapper (no regularisation).
sampling_strategy = NegativeSampling(
    model = kge_model,
    loss = SigmoidLoss(adv_temperature = 2),
    batch_size = train_loader.get_batch_size(),
    regul_rate = 0.0
)

# Train with Adam, then persist the learned embeddings.
runner = Trainer(model = sampling_strategy, data_loader = train_loader, train_times = 6000, alpha = 2e-5, use_gpu = True, opt_method = "adam", mode="single")
runner.run()
kge_model.save_checkpoint('./checkpoint/single_rotate.ckpt')

# Reload the checkpoint and evaluate unconstrained link prediction.
kge_model.load_checkpoint('./checkpoint/single_rotate.ckpt')
evaluator = Tester(model = kge_model, data_loader = test_loader, use_gpu = True)
evaluator.run_link_prediction(type_constrain = False)
71 changes: 71 additions & 0 deletions train_rotate_FB15K237_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Tensor-parallel RotatE training on FB15K237.
# Launch with:
#   python -m torch.distributed.launch --nproc_per_node <N> train_rotate_FB15K237_dist.py
import openke
from openke.config import Trainer, Tester
from openke.module.model import DistributedRotatE
from openke.module.loss import SigmoidLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader
import torch
import torch.distributed as dist
import argparse
import time
import numpy as np
import os

# Fix the seeds so every rank draws identical batches / negatives — required
# for the sharded all_reduce in DistributedRotatE to be consistent.
torch.manual_seed(1234)
np.random.seed(1234)

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank to each worker process.
parser.add_argument("--local_rank", default=-1, type=int)
args = parser.parse_args()
local_rank = args.local_rank

total_start_time = time.time()
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl', )
world_size = dist.get_world_size()

# dataloader for training
train_dataloader = TrainDataLoader(
    in_path = "./benchmarks/FB15K237/",
    batch_size = 2000,
    threads = 8,
    sampling_mode = "cross",
    bern_flag = 0,
    filter_flag = 1,
    neg_ent = 64,
    neg_rel = 0
)

# dataloader for test
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")

# define the model; each rank holds a 1/world_size slice of every embedding
rotate = DistributedRotatE(
    ent_tot = train_dataloader.get_ent_tot(),
    rel_tot = train_dataloader.get_rel_tot(),
    dim = 1024,
    margin = 6.0,
    epsilon = 2.0,
    world_size = world_size
)

# define the loss function
model = NegativeSampling(
    model = rotate,
    loss = SigmoidLoss(adv_temperature = 2),
    batch_size = train_dataloader.get_batch_size(),
    regul_rate = 0.0
)

# train the model
trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 6000, alpha = 2e-5, use_gpu = True, opt_method = "adam", save_steps=10000)
trainer.run()
total_end_time = time.time()

# BUG FIX: previously every rank appended to this file with no trailing
# newline, producing interleaved garbage such as
# "total elapsed time 30921.93...total elapsed time 30921.94..." (see the
# committed stats file). Only rank 0 reports now, terminated by a newline.
if dist.get_rank() == 0:
    with open("./total_cost_time_rotate.txt", "a+") as f:
        f.write("total elapsed time {}\n".format(total_end_time - total_start_time))

# BUG FIX: each rank holds a DIFFERENT shard of the embeddings, yet all
# ranks previously saved to (and later loaded from) the same path — the last
# writer silently clobbered every other rank's shard. Use a per-rank path so
# each shard round-trips to its own file.
ckpt_path = './checkpoint/rotate_rank{}.ckpt'.format(local_rank)
rotate.save_checkpoint(ckpt_path)

# test the model (each rank reloads its own shard; _calc all_reduces the
# per-shard scores, so all ranks must run the tester together)
rotate.load_checkpoint(ckpt_path)
tester = Tester(model = rotate, data_loader = test_dataloader, use_gpu = True)
tester.run_link_prediction(type_constrain = False)
Loading