Skip to content

Commit

Permalink
Update utils_pyRCTD.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xwanaf authored Jul 12, 2023
1 parent d77ed81 commit 9c0c8a9
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/utils_pyRCTD.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def psd(H):
# P = eig[1] @ np.diag(np.clip(eig[0], a_min = epsilon, a_max = eig[0].max() + 10)) @ eig[1].T
return P

def solveWLS(S,B,S_mat,initialSol, nUMI, bulk_mode = False, constrain = False, likelihood_vars = None):
def solveWLS(S,B,S_mat,initialSol, nUMI, bulk_mode = False, constrain = False, likelihood_vars = None, solver = 'osqp'):
solution = initialSol.copy()
solution[solution < 0] = 0
prediction = np.absolute(S @ solution)
Expand All @@ -176,13 +176,13 @@ def solveWLS(S,B,S_mat,initialSol, nUMI, bulk_mode = False, constrain = False, l
bzero = -solution
alpha = 0.3
if constrain:
solution = solution + alpha*solve_qp(np.array(D_mat),-np.array(d_vec),-np.array(A),-np.array(bzero), np.ones(solution.shape[0]), 1 - solution.sum(), solver="osqp")
solution = solution + alpha*solve_qp(np.array(D_mat),-np.array(d_vec),-np.array(A),-np.array(bzero), np.ones(solution.shape[0]), 1 - solution.sum(), solver=solver)
else:
solution = solution + alpha*solve_qp(np.array(D_mat),-np.array(d_vec),-np.array(A),-np.array(bzero), solver="osqp")
solution = solution + alpha*solve_qp(np.array(D_mat),-np.array(d_vec),-np.array(A),-np.array(bzero), solver=solver)
return solution

def solveIRWLS_weights(S,B,nUMI, OLS=False, constrain = True, verbose = False,
n_iter = 50, MIN_CHANGE = .001, bulk_mode = False, solution = None, loggings = None, likelihood_vars = None):
n_iter = 50, MIN_CHANGE = .001, bulk_mode = False, solution = None, loggings = None, likelihood_vars = None, solver = 'osqp'):
if not bulk_mode:
K_val = likelihood_vars['K_val']
B = np.copy(B)
Expand All @@ -193,7 +193,7 @@ def solveIRWLS_weights(S,B,nUMI, OLS=False, constrain = True, verbose = False,
change = 1
changes = []
while (change > MIN_CHANGE) & (iterations < n_iter):
new_solution = solveWLS(S,B,S_mat,solution, nUMI,constrain=constrain, bulk_mode = bulk_mode, likelihood_vars = likelihood_vars)
new_solution = solveWLS(S,B,S_mat,solution, nUMI,constrain=constrain, bulk_mode = bulk_mode, likelihood_vars = likelihood_vars, solver = solver)
change = np.linalg.norm(new_solution-solution, 1)
if verbose:
loggings.info('Change: {}'.format(change))
Expand All @@ -208,8 +208,12 @@ def decompose_full_ray(args):
bulk_mode = False
verbose = False
n_iter = 50
results = solveIRWLS_weights(cell_type_profiles,bead,nUMI,OLS = OLS, constrain = constrain,
verbose = verbose, n_iter = n_iter, MIN_CHANGE = MIN_CHANGE, bulk_mode = bulk_mode, loggings = loggings, likelihood_vars = likelihood_vars)
try:
results = solveIRWLS_weights(cell_type_profiles,bead,nUMI,OLS = OLS, constrain = constrain,
verbose = verbose, n_iter = n_iter, MIN_CHANGE = MIN_CHANGE, bulk_mode = bulk_mode, loggings = loggings, likelihood_vars = likelihood_vars, solver = 'osqp')
except:
results = solveIRWLS_weights(cell_type_profiles,bead,nUMI,OLS = OLS, constrain = constrain,
verbose = verbose, n_iter = n_iter, MIN_CHANGE = MIN_CHANGE, bulk_mode = bulk_mode, loggings = loggings, likelihood_vars = likelihood_vars, solver = 'cvxopt')
return results


Expand Down Expand Up @@ -690,4 +694,4 @@ def check_pairs_type(cell_type_profiles, bead, UMI_tot, score_mat, min_score, my
if all_pairs_class and not all_pairs and (len(other_class) > 1):
for ty in other_class[2:len(other_class)]:
singlet_score = min(singlet_score, get_singlet_score(cell_type_profiles, bead, UMI_tot, ty, constrain, MIN_CHANGE = MIN_CHANGE, loggings = loggings, likelihood_vars = likelihood_vars))
return {'all_pairs': all_pairs, 'all_pairs_class': all_pairs_class, 'singlet_score': singlet_score}
return {'all_pairs': all_pairs, 'all_pairs_class': all_pairs_class, 'singlet_score': singlet_score}

0 comments on commit 9c0c8a9

Please sign in to comment.