Skip to content

Commit

Permalink
fix bug load_from_file #148
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed May 14, 2019
1 parent 498a0b5 commit fab6844
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/thundersvm/thundersvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,19 +465,19 @@ def load_from_file(self, path):
n_feature = (c_int * 1)()
thundersvm.get_sv_max_index(c_void_p(self.model), n_feature)
self.n_features = n_feature[0]

csr_row = (c_int * (self.n_sv + 1))()
csr_col = (c_int * (self.n_sv * self.n_features))()
csr_data = (c_float * (self.n_sv * self.n_features))()
data_size = (c_int * 1)()
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, c_void_p(self.model))
sv_indices = (c_int * self.n_sv)()
thundersvm.get_sv(csr_row, csr_col, csr_data, data_size, sv_indices, c_void_p(self.model))
self.row = np.array([csr_row[index] for index in range(0, self.n_sv + 1)])
self.col = np.array([csr_col[index] for index in range(0, data_size[0])])
self.data = np.array([csr_data[index] for index in range(0, data_size[0])])
self.support_vectors_ = sp.csr_matrix((self.data, self.col, self.row))
# if self._sparse == False:
# self.support_vectors_ = self.support_vectors_.toarray(order = 'C')

self.support_ = np.array([sv_indices[index] for index in range(0, self.n_sv)]).astype(int)
dual_coef = (c_float * ((self.n_classes - 1) * self.n_sv))()
thundersvm.get_coef(dual_coef, self.n_classes, self.n_sv, c_void_p(self.model))
self.dual_coef_ = np.array([dual_coef[index] for index in range(0, (self.n_classes - 1) * self.n_sv)]).astype(float)
Expand Down

0 comments on commit fab6844

Please sign in to comment.