Commit 4894190: added dropouts to discrepancy

daneschi committed Jan 25, 2024
1 parent 804b7bc commit 4894190
Showing 6 changed files with 68 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ results/
*.npz
*.sur
*.csv
+ .linfa/
38 changes: 29 additions & 9 deletions linfa/discrepancy.py
@@ -19,6 +19,7 @@ def __init__(self, model_name,
var_grid_out,
dnn_arch=None,
dnn_activation='relu',
+ dnn_dropout=None,
model_folder='./',
surrogate=None,
device='cpu'):
@@ -32,6 +33,8 @@ def __init__(self, model_name,
self.input_size = None
self.output_size = None
self.dnn_arch = None
+ self.dnn_activation = None
+ self.dnn_dropout = None
self.is_trained = None
self.lf_model = None
self.var_grid_in = None
@@ -48,6 +51,8 @@ def __init__(self, model_name,
self.input_size = input_size
self.output_size = output_size
self.dnn_arch = dnn_arch
+ self.dnn_activation = dnn_activation
+ self.dnn_dropout = dnn_dropout
self.is_trained = False

# Assign LF model
@@ -70,7 +75,7 @@ def __init__(self, model_name,
self.var_out_std = torch.std(var_grid_out)

# Create surrogate
- self.surrogate = FNN(input_size, output_size, arch=self.dnn_arch, device=self.device, init_zero=True) if surrogate is None else surrogate
+ self.surrogate = FNN(input_size, output_size, arch=self.dnn_arch, device=self.device, init_zero=True, dropout=dnn_dropout) if surrogate is None else surrogate

def surrogate_save(self):
"""Save surrogate model to [self.name].sur and [self.name].npz
@@ -122,6 +127,7 @@ def update(self, batch_x, max_iters=10000, lr=0.01, lr_exp=0.999, record_interva
print('--- Training model discrepancy')
print('')

# Set it as trained
self.is_trained = True

# LF model output at the current batch
@@ -147,6 +153,7 @@ def update(self, batch_x, max_iters=10000, lr=0.01, lr_exp=0.999, record_interva
for i in range(max_iters):
# Set surrogate in training mode
self.surrogate.train()

# Surrogate returns a table with rows as batches and columns as variables considered
disc = self.surrogate(var_grid)

@@ -176,8 +183,14 @@ def update(self, batch_x, max_iters=10000, lr=0.01, lr_exp=0.999, record_interva
print('')
print('--- Surrogate model pre-train complete')
print('')
# Save if needed
if store:
self.surrogate_save()
+ # Keep the surrogate in train mode when dropout is used, so the dropout masks stay active at evaluation time; otherwise put it in eval mode
+ if(self.dnn_dropout is not None):
+     self.surrogate.train()
+ else:
+     self.surrogate.eval()
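Leaving the surrogate in train mode when dropout is configured is what enables Monte Carlo dropout: each forward pass re-samples the dropout masks, so repeated evaluations at the same inputs yield a spread of discrepancy values that can be read as an uncertainty estimate. A minimal sketch of drawing such samples (the helper name and n_samples argument are illustrative, not part of this commit):

import torch

def sample_discrepancy(discrepancy, var_grid, n_samples=50):
    # Assumes the surrogate was left in train mode, so each call
    # re-samples the dropout masks (MC dropout)
    samples = torch.stack([discrepancy.forward(var_grid).detach()
                           for _ in range(n_samples)])
    # Mean and spread across the stochastic forward passes
    return samples.mean(dim=0), samples.std(dim=0)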

def forward(self, var):
"""Function to evaluate the surrogate
@@ -200,7 +213,7 @@ def forward(self, var):
else:
return res

- def test_surrogate():
+ def test_discrepancy():

import matplotlib.pyplot as plt
from linfa.models.discrepancy_models import PhysChem
@@ -213,7 +226,7 @@ def test_surrogate():
model = PhysChem(var_grid)

# Generate true data
- model.genDataFile(dataFileNamePrefix='observations', use_true_model=True, store=True, num_observations=10)
+ model.genDataFile(dataFileNamePrefix='observations', use_true_model=True, store=True, num_observations=3)

# Get data from true model at the same TP conditions
var_data = np.loadtxt('observations.csv',skiprows=1,delimiter=',')
@@ -222,22 +235,29 @@

# Define emulator and pre-train on global grid
discrepancy = Discrepancy(model_name='discrepancy_test',
- lf_model=model.solve_lf,
+ lf_model=model.solve_t,
input_size=2,
output_size=1,
var_grid_in=var_data_in,
- var_grid_out=var_data_out)
+ var_grid_out=var_data_out,
+ dnn_arch=[64,64],
+ dnn_activation='relu',
+ dnn_dropout=[0.2,0.5])

# Create a batch of samples for the calibration parameters
batch_x = model.defParams

# Update the discrepancy model
- discrepancy.update(batch_x, max_iters=1000, lr=0.001, lr_exp=0.9999, record_interval=100)
+ discrepancy.update(batch_x, max_iters=10000, lr=0.001, lr_exp=0.9999, record_interval=100)

# Plot discrepancy
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
- ax.scatter(model.var_in[:,0].detach().numpy(), model.var_in[:,1].detach().numpy(), discrepancy.forward(model.var_in).detach().numpy(), marker='o')
+ for loopA in range(50):
+     ax.scatter(model.var_in[:,0].detach().numpy(), model.var_in[:,1].detach().numpy(), model.solve_t(batch_x)+discrepancy.forward(model.var_in).detach().numpy(), color='blue', marker='o')
+ for loopA in range(var_data_out.size(1)):
+     ax.scatter(model.var_in[:,0].detach().numpy(), model.var_in[:,1].detach().numpy(), var_data_out[:,loopA].detach().numpy(), color='red', marker='D', s=5)
ax.set_xlabel('Temperature')
ax.set_ylabel('Pressure')
ax.set_zlabel('Coverage')
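Because the surrogate stays in train mode when dropout rates are given, each of the 50 scatter calls above re-samples the dropout masks; the blue point cloud therefore visualizes the spread of the learned model-plus-discrepancy surface rather than a single deterministic prediction, with the red diamonds marking the observed data.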
@@ -246,5 +266,5 @@ def test_surrogate():
# TEST SURROGATE
if __name__ == '__main__':

- test_surrogate()
+ test_discrepancy()

2 changes: 2 additions & 0 deletions linfa/eval_model_from_chkpt.py
@@ -52,6 +52,8 @@ def eval_model(exp_chkpt_file,nf_chkpt_file,discr_chkpt_file,num_calib_samples,t
else:
res_lf = exp.model.solve_t(exp.transform.forward(xkk))

+ # CURRENTLY NO NOISE IS ADDED, NEEDS TO BE IMPLEMENTED IF APPROPRIATE!!!

# return
return res_lf + res_discr
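If observation noise is eventually added here, one option consistent with the proportional-noise convention in the PhysChem model (where stdRatio scales the noise standard deviation relative to the output) would be a sketch along these lines; reading exp.model.stdRatio at this point is an assumption, not something this commit implements:

# Hypothetical noise injection (NOT part of this commit): zero-mean Gaussian
# noise with standard deviation proportional to the noiseless prediction
res = res_lf + res_discr
noise = torch.randn_like(res) * exp.model.stdRatio * torch.abs(res)
return res + noise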

38 changes: 32 additions & 6 deletions linfa/mlp.py
@@ -5,7 +5,7 @@
class FNN(nn.Module):
"""Fully Connected Neural Network"""

- def __init__(self, input_size, output_size, arch=None, activation='relu', device='cpu', init_zero=False):
+ def __init__(self, input_size, output_size, arch=None, activation='relu', device='cpu', init_zero=False, dropout=None):
"""
Args:
input_size (int): Input size for FNN
@@ -32,6 +32,12 @@ def __init__(self, input_size, output_size, arch=None, activation='relu', device
self.fc.append(nn.Linear(self.neuron_size[layer-1], self.neuron_size[layer]).to(device))
# Assign output layer
self.fc.append(nn.Linear(self.neuron_size[len(self.neuron_size)-1], output_size).to(device))
+ # Add dropout layers
+ self.dropout = dropout
+ if(self.dropout is not None):
+     self.dropout_visible = nn.Dropout(p=self.dropout[0])
+     self.dropout_hidden = nn.Dropout(p=self.dropout[1])


def forward(self, x):
"""
@@ -41,17 +47,37 @@ def forward(self, x):
Returns:
torch.Tensor. Assumed to be a batch.
"""
+ if(self.dropout is not None):
+     x = self.dropout_visible(x)

for loopA in range(len(self.fc)-1):
- if(self.activation == 'relu'):
-     x = F.relu(self.fc[loopA](x))
+ if(self.activation == 'relu'):
+     if(self.dropout is not None):
+         x = self.dropout_hidden(F.relu(self.fc[loopA](x)))
+     else:
+         x = F.relu(self.fc[loopA](x))

elif(self.activation == 'silu'):
-     x = F.silu(self.fc[loopA](x))
+     if(self.dropout is not None):
+         x = self.dropout_hidden(F.silu(self.fc[loopA](x)))
+     else:
+         x = F.silu(self.fc[loopA](x))

elif(self.activation == 'tanh'):
-     x = F.tanh(self.fc[loopA](x))
+     if(self.dropout is not None):
+         x = self.dropout_hidden(F.tanh(self.fc[loopA](x)))
+     else:
+         x = F.tanh(self.fc[loopA](x))

else:
print('Invalid activation string.')
exit(-1)

# Last layer with linear activation
- x = self.fc[len(self.fc)-1](x)
+ x = self.fc[-1](x)
return x
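With this change, dropout=[p_visible, p_hidden] applies the first rate to the network inputs (the visible layer) and the second after every hidden activation, while the output layer stays linear and undropped; this matches the dnn_dropout=[0.2,0.5] pair passed in the discrepancy test. A minimal usage sketch (batch size and sample count are illustrative):

import torch
from linfa.mlp import FNN

# Two inputs, one output, matching the discrepancy test:
# 20% dropout on the inputs, 50% after each hidden layer
net = FNN(2, 1, arch=[64, 64], activation='relu', dropout=[0.2, 0.5])

x = torch.randn(10, 2)  # batch of 10 input points
net.train()             # keep dropout masks active (MC dropout)
samples = torch.stack([net(x) for _ in range(50)])
print(samples.mean(dim=0), samples.std(dim=0))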
4 changes: 3 additions & 1 deletion linfa/models/discrepancy_models.py
@@ -16,7 +16,9 @@ def __init__(self, var_inputs):
## model constants
self.RConst = torch.tensor(8.314) # universal gas constant (J/ mol/ K)
self.data = None # dataset of model
- self.stdRatio = 0.05 # standard deviation ratio
+ # self.stdRatio = 0.05 # standard deviation ratio
+ # TEMP TEST WITH LARGE NOISE ON DISCREPANCY - PUT IT BACK WHEN TEST FINISHED!!!
+ self.stdRatio = 0.1 # standard deviation ratio
self.defOut = self.solve_t(self.defParams)

def solve_t(self, cal_inputs):
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
"torch==1.13.1",
"numpy==1.21.6",
"matplotlib==3.5.3",
"dill",
'tomli; python_version < "3.11"',
]
requires-python = ">=3.7"
