Merge pull request #90 from AnFreTh/develop
prepro fixes
AnFreTh authored Aug 22, 2024
2 parents 6a78f8f + e78ed8b commit 6a20cfe
Showing 3 changed files with 39 additions and 38 deletions.
23 changes: 11 additions & 12 deletions stream_topic/preprocessor/_preprocessor.py
@@ -57,7 +57,7 @@ class TextPreprocessor:
     remove_words_with_numbers : bool, optional
         Whether to remove words containing numbers from the text data (default is False).
     remove_words_with_special_chars : bool, optional
-        Whether to remove words containing special characters from the text data (default is False).
+        Whether to remove words containing special characters from the text data (default is False).
     """

@@ -73,15 +73,18 @@ def __init__(self, **kwargs):
         self.remove_html_tags = kwargs.get("remove_html_tags", True)
         self.remove_special_chars = kwargs.get("remove_special_chars", True)
         self.remove_accents = kwargs.get("remove_accents", True)
-        self.custom_stopwords = set(kwargs.get("custom_stopwords", []))
+        self.custom_stopwords = (
+            set(kwargs.get("custom_stopwords", []))
+            if kwargs.get("custom_stopwords")
+            else set()
+        )
         self.detokenize = kwargs.get("detokenize", False)
         self.min_word_freq = kwargs.get("min_word_freq", 2)
         self.max_word_freq = kwargs.get("max_word_freq", None)
         self.min_word_length = kwargs.get("min_word_length", 3)
         self.max_word_length = kwargs.get("max_word_length", None)
         self.dictionary = set(kwargs.get("dictionary", []))
-        self.remove_words_with_numbers = kwargs.get(
-            "remove_words_with_numbers", False)
+        self.remove_words_with_numbers = kwargs.get("remove_words_with_numbers", False)
         self.remove_words_with_special_chars = kwargs.get(
             "remove_words_with_special_chars", False
         )
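
Note: the custom_stopwords change is more than a reformat. With the old line, a caller passing custom_stopwords=None would hit set(None), which raises a TypeError, since the key exists and the [] default is never used. A minimal standalone sketch of the difference (hypothetical kwargs, not repository code):

# Standalone sketch: why the falsy check matters (hypothetical kwargs).
kwargs = {"custom_stopwords": None}

# Old behavior: the key exists, so the [] default is skipped and set(None) raises.
try:
    set(kwargs.get("custom_stopwords", []))
except TypeError as err:
    print(f"old line fails: {err}")

# New behavior: any falsy value (None, [], "") falls back to an empty set.
custom_stopwords = (
    set(kwargs.get("custom_stopwords", []))
    if kwargs.get("custom_stopwords")
    else set()
)
print(custom_stopwords)  # set()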
@@ -186,23 +189,19 @@ def _clean_text(self, text):
         ]

         if self.min_word_length is not None:
-            words = [word for word in words if len(
-                word) >= self.min_word_length]
+            words = [word for word in words if len(word) >= self.min_word_length]

         if self.max_word_length is not None:
-            words = [word for word in words if len(
-                word) <= self.max_word_length]
+            words = [word for word in words if len(word) <= self.max_word_length]

         if self.dictionary != set():
             words = [word for word in words if word in self.dictionary]

         if self.remove_words_with_numbers:
-            words = [word for word in words if not any(
-                char.isdigit() for char in word)]
+            words = [word for word in words if not any(char.isdigit() for char in word)]

         if self.remove_words_with_special_chars:
-            words = [word for word in words if not re.search(
-                r"[^a-zA-Z0-9\s]", word)]
+            words = [word for word in words if not re.search(r"[^a-zA-Z0-9\s]", word)]

         if self.detokenize:
             text = TreebankWordDetokenizer().detokenize(words)
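
The _clean_text edits above are behavior-preserving: each two-line comprehension is collapsed onto one line. A standalone sketch of the filter chain on made-up tokens and thresholds:

# Standalone sketch of the filter chain (made-up tokens and thresholds).
import re

words = ["ai", "hello", "abc123", "c++", "tokenization"]
min_word_length, max_word_length = 3, 10

words = [w for w in words if len(w) >= min_word_length]            # drops "ai"
words = [w for w in words if len(w) <= max_word_length]            # drops "tokenization"
words = [w for w in words if not any(c.isdigit() for c in w)]      # drops "abc123"
words = [w for w in words if not re.search(r"[^a-zA-Z0-9\s]", w)]  # drops "c++"
print(words)  # ['hello']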
7 changes: 3 additions & 4 deletions stream_topic/preprocessor/_tf_idf.py
@@ -21,7 +21,7 @@ def c_tf_idf(documents, m, ngram_range=(1, 1)):
     w = t.sum(axis=1)

     # Suppress divide by zero warning
-    with np.errstate(divide='ignore', invalid='ignore'):
+    with np.errstate(divide="ignore", invalid="ignore"):
         tf = np.divide(t.T, w)
         if np.any(np.isnan(tf)) or np.any(np.isinf(tf)):
             logger.warning("NaNs or inf in tf matrix")
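
The errstate context exists because w can contain zeros (a class with no terms), so np.divide(t.T, w) would emit divide-by-zero or invalid-value runtime warnings; the code silences them and logs once instead. A standalone sketch with made-up counts:

# Standalone sketch: a zero class total turns into nan without a console warning.
import numpy as np

t = np.array([[2.0, 1.0], [0.0, 0.0]])  # term counts per class; second class is empty
w = t.sum(axis=1)                        # class totals: [3.0, 0.0]

with np.errstate(divide="ignore", invalid="ignore"):
    tf = np.divide(t.T, w)               # 0/0 entries become nan, silently

print(np.isnan(tf).any() or np.isinf(tf).any())  # True -> triggers the logged warning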
@@ -34,7 +34,7 @@ def c_tf_idf(documents, m, ngram_range=(1, 1)):
     return tf_idf, count


-def extract_tfidf_topics(tf_idf, count, docs_per_topic, n=10):
+def extract_tfidf_topics(tf_idf, count, docs_per_topic, n=100):
     """class based tf_idf retrieval from cluster of documents
     Args:
@@ -51,8 +51,7 @@ def extract_tfidf_topics(tf_idf, count, docs_per_topic, n=10):
     tf_idf_transposed = tf_idf.T
     indices = tf_idf_transposed.argsort()[:, -n:]
     top_n_words = {
-        label: [((words[j]), (tf_idf_transposed[i][j]))
-                for j in indices[i]][::-1]
+        label: [((words[j]), (tf_idf_transposed[i][j])) for j in indices[i]][::-1]
         for i, label in enumerate(labels)
     }
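
The only functional change in this file is the default n rising from 10 to 100 top words per topic; the comprehension reflow is cosmetic. For orientation, the return value maps each cluster label to (word, score) pairs, best first, so callers can still truncate. A sketch with fabricated scores:

# Standalone sketch of the extract_tfidf_topics return shape (fabricated data).
top_n_words = {
    0: [("game", 0.42), ("team", 0.31), ("season", 0.12)],
    1: [("space", 0.51), ("nasa", 0.29), ("orbit", 0.08)],
}
for label, pairs in top_n_words.items():
    print(label, [word for word, _ in pairs][:10])  # take only the top 10 if desired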

47 changes: 25 additions & 22 deletions stream_topic/visuals/visuals.py
@@ -3,9 +3,12 @@

 from ..models.abstract_helper_models.base import BaseModel, TrainingStatus
 from ..utils import TMDataset
-from ._interactive import (_visualize_topic_model_2d,
-                           _visualize_topic_model_3d, _visualize_topics_2d,
-                           _visualize_topics_3d)
+from ._interactive import (
+    _visualize_topic_model_2d,
+    _visualize_topic_model_3d,
+    _visualize_topics_2d,
+    _visualize_topics_3d,
+)
 from ._octis_visuals import OctisWrapperVisualModel


@@ -43,7 +46,7 @@ def visualize_topics_as_wordclouds(
         hasattr(model, "topic_dict") and model._status == TrainingStatus.SUCCEEDED
     ), "Model must have been trained with topics extracted."

-    topics = model.get_topics()
+    topics = model.topic_dict

     for topic_id, topic_words in topics.items():
         # Generate a word frequency dictionary for the topic
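
Reading topic_dict directly avoids the get_topics() accessor. Assuming topic_dict maps a topic id to (word, weight) pairs (an assumption about the structure, with made-up data here), the word-cloud step amounts to:

# Sketch: (word, weight) pairs -> frequency dict -> word cloud image.
# Assumes topic_dict values look like [("word", weight), ...]; data is made up.
from wordcloud import WordCloud

topic_words = [("neural", 0.9), ("network", 0.7), ("layer", 0.4)]
freqs = {word: weight for word, weight in topic_words}
cloud = WordCloud(width=800, height=400).generate_from_frequencies(freqs)
cloud.to_file("topic_0.png")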
@@ -153,24 +156,24 @@ def visualize_topics(
     use_average: bool = True,
 ):
     """
-    Visualize topics in either 2D or 3D space using UMAP, t-SNE, or PCA dimensionality reduction techniques.
-    Args:
-        model (AbstractModel): The trained topic model instance.
-        model_output (dict, optional): The output of the topic model, typically including topic-word distributions and document-topic distributions. Required if the model does not have an 'output' attribute.
-        dataset (TMDataset, optional): The dataset used for training the topic model. Required if the model does not have an 'output' attribute.
-        three_dim (bool, optional): Flag to visualize in 3D if True, otherwise in 2D. Defaults to False.
-        reduce_first (bool, optional): Indicates whether to perform dimensionality reduction on embeddings before computing topic centroids. Defaults to False.
-        reducer (str, optional): Choice of dimensionality reduction technique. Supported values are 'umap', 'tsne', and 'pca'. Defaults to 'umap'.
-        port (int, optional): The port number on which the visualization dashboard will run. Defaults to 8050.
-        embedding_model_name (str, optional): Name of the embedding model used for generating document embeddings. Defaults to "all-MiniLM-L6-v2".
-        embeddings_folder_path (str, optional): Path to the folder containing precomputed embeddings. If not provided, embeddings will be computed on the fly.
-        embeddings_file_path (str, optional): Path to the file containing precomputed embeddings. If not provided, embeddings will be computed on the fly.
-    Returns:
-        None
-    The function launches a Dash server to visualize the topic model.
+    Visualize topics in either 2D or 3D space using UMAP, t-SNE, or PCA dimensionality reduction techniques.
+    Args:
+        model (AbstractModel): The trained topic model instance.
+        model_output (dict, optional): The output of the topic model, typically including topic-word distributions and document-topic distributions. Required if the model does not have an 'output' attribute.
+        dataset (TMDataset, optional): The dataset used for training the topic model. Required if the model does not have an 'output' attribute.
+        three_dim (bool, optional): Flag to visualize in 3D if True, otherwise in 2D. Defaults to False.
+        reduce_first (bool, optional): Indicates whether to perform dimensionality reduction on embeddings before computing topic centroids. Defaults to False.
+        reducer (str, optional): Choice of dimensionality reduction technique. Supported values are 'umap', 'tsne', and 'pca'. Defaults to 'umap'.
+        port (int, optional): The port number on which the visualization dashboard will run. Defaults to 8050.
+        embedding_model_name (str, optional): Name of the embedding model used for generating document embeddings. Defaults to "all-MiniLM-L6-v2".
+        embeddings_folder_path (str, optional): Path to the folder containing precomputed embeddings. If not provided, embeddings will be computed on the fly.
+        embeddings_file_path (str, optional): Path to the file containing precomputed embeddings. If not provided, embeddings will be computed on the fly.
+    Returns:
+        None
+    The function launches a Dash server to visualize the topic model.
     """
     if not isinstance(model, BaseModel):
