
Commit

closes #51
LBlend committed Apr 10, 2022
1 parent 3bd7f17 commit 7e8694c
Showing 4 changed files with 107 additions and 2 deletions.
7 changes: 6 additions & 1 deletion build.sh
100644 → 100755
@@ -14,11 +14,16 @@ fi
echo "Downloading NLTK stopwords and punctuation packages..."
python3 -c "import nltk; nltk.download('stopwords'); nltk.download('punkt')"

-if [ ! -e 'src/bin/bayes_model.pkl' ]; then
+if [ ! -e 'src/bin/bayes_model_sk.pkl' ]; then
echo "Training Bayes model..."
python3 ./src/scripts/train_bayes.py
fi

if [ ! -e 'src/bin/logreg_model_sk.pkl' ]; then
echo "Training LogReg model..."
python3 ./src/scripts/train_logreg.py
fi

if [ ! -d 'src/bin/rnn' ]; then
echo "Training RNN model..."
python3 ./src/scripts/train_rnn.py
14 changes: 13 additions & 1 deletion src/classifier.py
@@ -10,6 +10,9 @@
with open("src/bin/bayes_model_sk.pkl", "rb") as f:
bayes_model = pickle.load(f)

with open("src/bin/logreg_model_sk.pkl", "rb") as f:
logreg_model = pickle.load(f)

with open("src/bin/vectorizer.pkl", "rb") as f:
vectorizer = pickle.load(f)

@@ -25,14 +28,23 @@ def predict_bayes(text: str) -> dict[str, float]:
    }


def predict_logreg(text: str) -> dict[str, float]:
    features = vectorizer.transform([text])
    probs = logreg_model.predict_proba(features)[0]
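    # classes_ is sorted alphabetically (["F", "M"]), so probs[0] = P(F) and probs[1] = P(M)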
    return {
        "male": probs[1],
        "female": probs[0],
    }


def predict_rnn(text: str) -> dict[str, float]:
    pred_arr = rnn_model.predict(np.array([text]))
    pred = float(pred_arr[0, 0])
    # F = 0, M = 1, P(!A) == 1 - P(A)
    return {"male": pred, "female": 1 - pred}


pred_funcs = {"bayes": predict_bayes, "rnn": predict_rnn}
pred_funcs = {"bayes": predict_bayes, "rnn": predict_rnn, "logreg": predict_logreg}


def predict(text: str, classifier: str = "bayes") -> dict[str, str | dict[str, float]]:
64 changes: 64 additions & 0 deletions src/scripts/train_logreg.py
@@ -0,0 +1,64 @@
import nltk
import numpy as np
import os
import pickle
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegressionCV

MODEL_PATH = "src/bin/logreg_model_sk.pkl"
VECTORIZER_PATH = "src/bin/vectorizer.pkl"

STOPWORDS = nltk.corpus.stopwords.words("norwegian")
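# Fix the NumPy RNG seed so the shuffle in train() is reproducible.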
np.random.seed(42)


def read_file(path: str) -> str:
    with open(path) as f:
        content = f.read()
    return content


def _load_corpus() -> tuple[list[str], np.ndarray]:
    dir_f = "corpus/data/train/F/"
    dir_m = "corpus/data/train/M/"
    files_f = [dir_f + i for i in os.listdir(dir_f)]
    files_m = [dir_m + i for i in os.listdir(dir_m)]
    labels_f = np.full(len(files_f), "F")
    labels_m = np.full(len(files_m), "M")

    raw_data = list(map(read_file, files_f + files_m))
    labels = np.concatenate((labels_f, labels_m))
    return raw_data, labels


def _save_model(model: LogisticRegressionCV, path: str) -> None:
    with open(path, "wb+") as f:
        pickle.dump(model, f)


def train() -> None:
    raw_data, labels = _load_corpus()

    vectorizer = CountVectorizer(
        stop_words=STOPWORDS,
        ngram_range=(1, 3),  # include unigrams, bigrams and trigrams
        max_features=5000,
    )

    features = vectorizer.fit_transform(raw_data)
    # Shuffle the data in case the model is not permutation invariant;
    # scipy's CSR matrix supports integer-array row indexing, so features[perms] reorders rows.
    perms = np.random.permutation(len(labels))
    features = features[perms]
    labels = labels[perms]

    clf = LogisticRegressionCV(multi_class="multinomial", cv=5, max_iter=5000)

    clf.fit(features, labels)

    _save_model(clf, MODEL_PATH)
    _save_model(vectorizer, VECTORIZER_PATH)


if __name__ == "__main__":
    train()
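
For reference, a minimal sketch of how the two pickled artifacts written by this script can be consumed at inference time; it mirrors the loading code added to src/classifier.py above, and the example sentence is made up:

import pickle

with open("src/bin/logreg_model_sk.pkl", "rb") as f:
    logreg_model = pickle.load(f)

with open("src/bin/vectorizer.pkl", "rb") as f:
    vectorizer = pickle.load(f)

# Hypothetical input sentence; any Norwegian text works.
features = vectorizer.transform(["Dette er en eksempelsetning."])

# Map each class label ("F"/"M") to its predicted probability.
print(dict(zip(logreg_model.classes_, logreg_model.predict_proba(features)[0])))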
24 changes: 24 additions & 0 deletions src/scripts/validation.py
@@ -38,11 +38,31 @@ def predict_on_str(text: str) -> int:
    return lambda x: [predict_on_str(s) for s in x]


def get_predict_logreg() -> list[int]:
    model_path = "src/bin/logreg_model_sk.pkl"
    vectorizer_path = "src/bin/vectorizer.pkl"

    with open(model_path, "rb") as f:
        logreg_model = pickle.load(f)

    with open(vectorizer_path, "rb") as f:
        vectorizer = pickle.load(f)

    def predict_on_str(text: str) -> int:
        text = text.decode("UTF8")
        features = vectorizer.transform([text])
        pred = logreg_model.predict(features)
        return 0 if pred[0] == "F" else 1

    return lambda x: [predict_on_str(s) for s in x]


if __name__ == "__main__":
    dev_set, test_set = get_dev_test()

    predict_rnn = get_predict_rnn("src/bin/rnn")
    predict_bayes = get_predict_bayes()
    predict_logreg = get_predict_logreg()

    for name, (X, y) in [("dev", dev_set), ("test", test_set)]:
        X = X.numpy()
@@ -53,3 +73,7 @@ def predict_on_str(text: str) -> int:
        y_pred = predict_bayes(X)
        print("Results for", name, "with bayes:")
        print(classification_report(y, y_pred))

        y_pred = predict_logreg(X)
        print("Results for", name, "with logreg:")
        print(classification_report(y, y_pred))
