Skip to content

Commit

Permalink
support balanced class_weight #141
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed May 6, 2019
1 parent 8466021 commit 6e28da8
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ The usage of thundersvm scikit interface is similar to sklearn.svm.
*probability*: boolean, optional(default=False)\
whether to train a SVC or SVR model for probability estimates, True or False

*class_weight*: {dict}, optional(default=None)\
set the parameter C of class i to weight*C, for C-SVC
*class_weight*: {dict, 'balanced'}, optional(default=None)\
set the parameter C of class i to weight*C, for C-SVC. If not given, all classes are assumed to have weight one. The 'balanced' mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data, as ```n_samples / (n_classes * np.bincount(y))```

*shrinking*: boolean, optional (default=False, not supported yet for True)\
whether to use the shrinking heuristic.
Expand Down
52 changes: 44 additions & 8 deletions python/thundersvm/thundersvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,30 @@ def _dense_fit(self, X, y, solver_type, kernel):
if self.class_weight is None:
weight_size = 0
self.class_weight = dict()
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())
elif self.class_weight == 'balanced':
y_unique = np.unique(y)
y_count = np.bincount(y.astype(int))
weight_label_list = []
weight_list = []
for n in range(0, len(y_count)):
if y_count[n] != 0:
weight_label_list.append(n)
weight_list.append(samples/(len(y_unique)*y_count[n]))
weight_size=len(weight_list)
weight_label = (c_int * weight_size)()
weight_label[:] = weight_label_list
weight = (c_float * weight_size)()
weight[:] = weight_list
else:
weight_size = len(self.class_weight)
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())

n_features = (c_int * 1)()
n_classes = (c_int * 1)()
Expand Down Expand Up @@ -228,12 +246,30 @@ def _sparse_fit(self, X, y, solver_type, kernel):
if self.class_weight is None:
weight_size = 0
self.class_weight = dict()
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())
elif self.class_weight == 'balanced':
y_unique = np.unique(y)
y_count = np.bincount(y.astype(int))
weight_label_list = []
weight_list = []
for n in range(0, len(y_count)):
if y_count[n] != 0:
weight_label_list.append(n)
weight_list.append(X.shape[0]/(len(y_unique)*y_count[n]))
weight_size=len(weight_list)
weight_label = (c_int * weight_size)()
weight_label[:] = weight_label_list
weight = (c_float * weight_size)()
weight[:] = weight_list
else:
weight_size = len(self.class_weight)
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())
weight_label = (c_int * weight_size)()
weight_label[:] = list(self.class_weight.keys())
weight = (c_float * weight_size)()
weight[:] = list(self.class_weight.values())

n_features = (c_int * 1)()
n_classes = (c_int * 1)()
Expand Down
2 changes: 1 addition & 1 deletion src/thundersvm/thundersvm-scikit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ extern "C" {
param_cmd.max_mem_size = static_cast<size_t>(max(max_mem_size, 0)) << 20;
if(weight_size != 0) {
param_cmd.nr_weight = weight_size;
param_cmd.weight = (float_type *) malloc(weight_size * sizeof(float_type));
param_cmd.weight = (float_type *) malloc(weight_size * sizeof(double));
param_cmd.weight_label = (int *) malloc(weight_size * sizeof(int));
for (int i = 0; i < weight_size; i++) {
param_cmd.weight[i] = weight[i];
Expand Down

0 comments on commit 6e28da8

Please sign in to comment.