diff --git a/README.md b/README.md
index f5cc755d99..d4b06d585a 100644
--- a/README.md
+++ b/README.md
@@ -2,16 +2,16 @@
-[![PyPI](tbd)](tbd)
-![PyPI - Downloads](tbd)
-[![docs build](https://readthedocs.org/projects/stream/badge/?version=latest)](tbd)
-[![docs](https://img.shields.io/badge/docs-latest-blue)](tbd)
+[![PyPI](https://img.shields.io/pypi/v/stream_topic)](https://pypi.org/project/stream_topic)
+![PyPI - Downloads](https://img.shields.io/pypi/dm/stream_topic)
+[![docs build](https://readthedocs.org/projects/stream-topic/badge/?version=latest)](https://stream-topic.readthedocs.io/en/latest/?badge=latest)
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://stream-topic.readthedocs.io/en/latest/index.html)
 [![open issues](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/AnFreTh/STREAM/issues)
 
-[📘Documentation](tbd) |
-[🛠️Installation](#installation) |
-[Models](#available-models) |
+[📘Documentation](https://stream-topic.readthedocs.io/en/latest/index.html) |
+[🛠️Installation](https://stream-topic.readthedocs.io/en/latest/installation.html) |
+[Models](https://stream-topic.readthedocs.io/en/latest/api/models/index.html) |
 [🤔Report Issues](https://github.com/AnFreTh/STREAM/issues)
diff --git a/docs/conf.py b/docs/conf.py
index 2cc8feb950..ba90d793ba 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -4,6 +4,7 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
 
 import os
+
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 import pathlib
@@ -11,8 +12,7 @@
 sys.path.insert(0, os.path.abspath("."))
 sys.path.insert(0, os.path.abspath("../"))
-sys.path.insert(1, os.path.dirname(
-    os.path.abspath("../")) + os.sep + "stream")
+sys.path.insert(1, os.path.dirname(os.path.abspath("../")) + os.sep + "stream")
 
 project = "stream_topic"
 author = "Anton Frederik Thielmann"
@@ -37,7 +37,7 @@
     "sphinx_copybutton",
     "sphinx_togglebutton",
     "nbsphinx",
-    'myst_nb',
+    "myst_nb",
     "IPython.sphinxext.ipython_console_highlighting",
     "IPython.sphinxext.ipython_directive",
     # "myst_parser",
@@ -64,7 +64,6 @@
     "gensim",
     "nltk",
     "langdetect",
-    "octis",
     "loguru",
     "scipy",
     "community",
@@ -75,9 +74,7 @@
     "umap_learn",
     "dash",
     "optuna",
-    "optuna-integration"
-    ""
-
+    "optuna-integration",
 ]
 
 # Add any paths that contain templates here, relative to this directory.
@@ -123,8 +120,8 @@
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
 html_theme = "sphinx_book_theme"
 
-html_static_path = ['_static']
-html_css_files = ['css/stream_theme.css']
+html_static_path = ["_static"]
+html_css_files = ["css/stream_theme.css"]
 
 # html_css_files = ['custom.css']
 # html_js_files = ['custom.js']
@@ -161,12 +158,12 @@
 # -- Options for myst ----------------------------------------------
 # uncomment line below to avoid running notebooks during development
-nb_execution_mode = 'off'
+nb_execution_mode = "off"
 
 # Notebook cell execution timeout; defaults to 30.
 nb_execution_timeout = 100
 
 # List of patterns, relative to source directory, that match notebook
 # files that will not be executed.
-myst_enable_extensions = ['dollarmath']
+myst_enable_extensions = ["dollarmath"]
 
 # raise exceptions on execution so CI can catch errors
 nb_execution_allow_errors = False
@@ -178,4 +175,6 @@
     # other MyST extensions
 ]
 
-mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
+mathjax_path = (
+    "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
+)
diff --git a/examples/KmeansTM.ipynb b/examples/KmeansTM.ipynb
deleted file mode 100644
index ebd0e293e4..0000000000
--- a/examples/KmeansTM.ipynb
+++ /dev/null
@@ -1,302 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Import and load dataset"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "If you have not installed the package and work inside the repository, use the following two lines of code to make the example word:\n",
-    "\n",
-    "import sys\n",
-    "sys.path.append(\"..\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n",
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\torch\\_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
-      "  return self.fget.__get__(instance, owner)()\n",
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\dash\\_jupyter.py:31: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
-      "  _dash_comm = Comm(target_name=\"dash\")\n"
-     ]
-    }
-   ],
-   "source": [
-    "from stream.utils import TMDataset\n",
-    "from stream.models import KmeansTM\n",
-    "\n",
-    "dataset = TMDataset()\n",
-    "dataset.fetch_dataset(\"20NewsGroup\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'en'"
-      ]
-     },
-     "execution_count": 2,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "dataset.language"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Preprocessing documents:   0%|          | 0/2500 [00:00
-      "\n",
-      "        "
-     ],
-     "text/plain": [
-      ""
-     ]
-    },
-    {
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "from stream.visuals import visualize_topic_model, visualize_topics\n",
-    "visualize_topic_model(model, dataset=dataset, port=8051)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "\n",
-       "        \n",
-       "        "
-      ],
-      "text/plain": [
-       ""
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "visualize_topics(model, port=8052)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "STREAM_venv",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/downstream_task.ipynb b/examples/downstream_task.ipynb
deleted file mode 100644
index ac52f877aa..0000000000
--- a/examples/downstream_task.ipynb
+++ /dev/null
@@ -1,152 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Import and load dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from stream.data_utils import TMDataset\n",
-    "from stream.models import CEDC\n",
-    "\n",
-    "dataset = TMDataset()\n",
-    "dataset.fetch_dataset(\"Spotify_random\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Train the model\n",
-    "If embeddings for the model have been created before, they will not be created again for faster computation"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model = CEDC(num_topics=5)  # Create model\n",
-    "model_output = model.train_model(dataset) "
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Visualize topics"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from stream.visuals import visualize_topic_model\n",
-    "visualize_topic_model(model, port=8053)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Fit Downstream Model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from pytorch_lightning import Trainer\n",
-    "from stream.NAM import DownstreamModel\n",
-    "\n",
-    "# Instantiate the DownstreamModel\n",
-    "downstream_model = DownstreamModel(\n",
-    "    trained_topic_model=model,\n",
-    "    target_column='popularity',  # Target variable\n",
-    "    task='regression',  # or 'classification'\n",
-    "    dataset=dataset, \n",
-    "    batch_size=512,\n",
-    "    lr=0.0005\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Train the downstream model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Use PyTorch Lightning's Trainer to train and validate the model\n",
-    "trainer = Trainer(max_epochs=50)\n",
-    "trainer.fit(downstream_model)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "trainer.validate()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Visualize Feature and Topic contributions"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from stream.visuals import plot_downstream_model\n",
-    "\n",
-    "plot_downstream_model(downstream_model)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "ExpandedTM_venv",
-   "language": "python",
-   "name": "expandedtm_venv"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.9.18"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/examples/load_own_data.ipynb b/examples/load_own_data.ipynb
deleted file mode 100644
index c3bdfa836b..0000000000
--- a/examples/load_own_data.ipynb
+++ /dev/null
@@ -1,159 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Since STREAM is using the preprocessing from the octis core package, you must download some spacy specific utils.\n",
-    "\n",
-    "python -m spacy download en_core_web_sm before."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n",
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\torch\\_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
-      "  return self.fget.__get__(instance, owner)()\n",
-      "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\dash\\_jupyter.py:31: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
-      "  _dash_comm = Comm(target_name=\"dash\")\n"
-     ]
-    }
-   ],
-   "source": [
-    "import sys\n",
-    "sys.path.append(\"..\")\n",
-    "from stream.utils import TMDataset\n",
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "\n",
-    "\n",
-    "# Simulating some example data\n",
-    "np.random.seed(0)  # For reproducibility\n",
-    "\n",
-    "# Generate 1000 random strings of lengths between 1 and 5, containing letters 'A' to 'Z'\n",
-    "random_documents = [''.join(np.random.choice(list('ABCDEFGHIJKLMNOPQRSTUVWXYZ'), \n",
-    "                                             np.random.randint(1, 6))) for _ in range(1000)]\n",
-    "\n",
-    "# Generate 1000 random labels from 1 to 4 as strings\n",
-    "random_labels = np.random.choice(['1', '2', '3', '4'], 1000)\n",
-    "\n",
-    "# Create DataFrame\n",
-    "my_data = pd.DataFrame({\"Documents\": random_documents, \"Labels\": random_labels})\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "pass the dataframe to the create_load_save_dataset function and specify the used columns. use label_column=None if no labels are available.\n",
-    "The dataset is preprocessed and saved and directly returned. If you want to use your dataset later, you can simply run dataset.fetch_dataset(your_dataset_path)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Preprocessing documents: 100%|██████████| 1000/1000 [00:13<00:00, 75.92it/s]\n"
-     ]
-    }
-   ],
-   "source": [
-    "dataset = TMDataset()\n",
-    "dataset = dataset.create_load_save_dataset(\n",
-    "    data=my_data, \n",
-    "    dataset_name=\"my_dataset_name\",\n",
-    "    save_dir=\"my_dataset_save_directory\",\n",
-    "    doc_column=\"Documents\",\n",
-    "    label_column=\"Labels\"\n",
-    "    )"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# To fecth the dataset, simply run"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "dataset = TMDataset()\n",
-    "dataset.fetch_dataset(name=\"my_dataset_name\", dataset_path=\"my_dataset_save_directory\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "\u001b[32m2024-06-23 13:19:19.418\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m262\u001b[0m - \u001b[1m--- Training KmeansTM topic model ---\u001b[0m\n",
-      "\u001b[32m2024-06-23 13:19:19.419\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_prepare_embeddings\u001b[0m:\u001b[36m171\u001b[0m - \u001b[1m--- Creating paraphrase-MiniLM-L3-v2 document embeddings ---\u001b[0m\n",
-      "100%|██████████| 1000/1000 [00:06<00:00, 155.09it/s]\n",
-      "\u001b[32m2024-06-23 13:19:25.980\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_dim_reduction\u001b[0m:\u001b[36m194\u001b[0m - \u001b[1m--- Reducing dimensions ---\u001b[0m\n",
-      "\u001b[32m2024-06-23 13:19:35.562\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_clustering\u001b[0m:\u001b[36m214\u001b[0m - \u001b[1m--- Creating document cluster ---\u001b[0m\n",
-      "\u001b[32m2024-06-23 13:19:35.810\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m295\u001b[0m - \u001b[1m--- Training completed successfully. ---\u001b[0m\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[['vb', 'ld', 'rd', 'ncb', 'hmaul', 'vjebr', 'vhk', 'vg', 'vfhy', 'vf'], ['ps', 'lot', 'mu', 'nose', 'psce', 'warm', 'fl', 'byvet', 'jp', 'id'], ['eq', 'et', 'eqzdh', 'eqtoi', 'ev', 'ejiu', 'eyvh', 'ezako', 'egum', 'egarx'], ['hv', 'hws', 'hxohw', 'hu', 'htnsy', 'htafl', 'hxr', 'hyyp', 'hb', 'hbbh'], ['pe', 'ki', 'pekr', 'krlri', 'kkacb', 'kmgd', 'kmr', 'kn', 'knsb', 'knvb'], ['jd', 'gj', 'ullfq', 'pfjw', 'wrgh', 'ws', 'ortdj', 'jvh', 'wwm', 'jjma'], ['zd', 'zzg', 'zmism', 'zya', 'ixtp', 'sqzh', 'wuchn', 'zanmg', 'zarh', 'zfnv'], ['ov', 'oa', 'oinjo', 'ocro', 'ooson', 'oou', 'oat', 'oqub', 'oygny', 'oy'], ['tkg', 'tlg', 'oflt', 'tgo', 'tm', 'jp', 'jon', 'jo', 'jmmt', 'jmafw'], ['xsk', 'xooqd', 'xqfca', 'xqhlk', 'xo', 'xtz', 'xn', 'xyy', 'xly', 'xzos'], ['si', 'ysd', 'ybf', 'yrtt', 'ys', 'yp', 'yn', 'yv', 'yydsp', 'yynx'], ['dy', 'dp', 'dsrbs', 'dmeb', 'dmeui', 'dmgpe', 'dgcf', 'dfmsa', 'ds', 'dn'], ['sb', 'en', 'ab', 'rf', 'bveff', 'bue', 'bma', 'ukkur', 'bh', 'bee'], ['qci', 'qcp', 'qcl', 'qcw', 'qcfuc', 'qcbwn', 'qc', 'qco', 'hrf', 'jqgvd'], ['vz', 'vmzk', 'gg', 'woz', 'kwza', 'ouvd', 'oszm', 'osza', 'wvvz', 'osao'], ['lf', 'lml', 'lrjaw', 'lxg', 'lxc', 'lx', 'lwaen', 'lvmpx', 'yfcml', 'lqd'], ['ib', 'og', 'uc', 'da', 'ua', 'dh', 'ih', 'th', 'ng', 'tarcs'], ['fb', 'mrqp', 'wqp', 'fsyq', 'whmqq', 'ecq', 'qifpf', 'fvo', 'fvq', 'wnxzq'], ['rk', 'hyzym', 'rgpcv', 'rhb', 'rhc', 'rhk', 'ftxyt', 'rkh', 'rkm', 'rpq'], ['mw', 'qld', 'qmfu', 'qksxu', 'km', 'qlbz', 'qysf', 'qlv', 'qm', 'qni']]\n"
-     ]
-    }
-   ],
-   "source": [
-    "from stream.models import KmeansTM\n",
-    "# -> specify a existing folder path where to save the embeddings (or where to load the pre embedded dataset)\n",
-    "model = KmeansTM(embeddings_folder_path=\"../my_embedding_folder\") \n",
-    "# -> set the following arguments for num_topic optimization: KmeansTM(optim = True, optim_range = [5, 25])\n",
-    "model.fit(dataset) \n",
-    "topics = model.get_topics()\n",
-    "print(topics)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "stream_venv",
-   "language": "python",
-   "name": "stream_venv"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/stream_topic/__version__.py b/stream_topic/__version__.py
index 7c5d2fe9fe..10b95664f1 100644
--- a/stream_topic/__version__.py
+++ b/stream_topic/__version__.py
@@ -1,4 +1,4 @@
 """Version information."""
 
 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.0"
+__version__ = "0.1.1"
diff --git a/stream_topic/metrics/TopwordEmbeddings.py b/stream_topic/metrics/TopwordEmbeddings.py
index 5ce37bcbc5..7bbaf3d210 100644
--- a/stream_topic/metrics/TopwordEmbeddings.py
+++ b/stream_topic/metrics/TopwordEmbeddings.py
@@ -1,12 +1,14 @@
-
 import os
 import pickle
 
 import numpy as np
 from sentence_transformers import SentenceTransformer
 
-from .constants import (EMBEDDING_PATH, PARAPHRASE_TRANSFORMER_MODEL,
-                        SENTENCE_TRANSFORMER_MODEL)
+from .constants import (
+    EMBEDDING_PATH,
+    PARAPHRASE_TRANSFORMER_MODEL,
+    SENTENCE_TRANSFORMER_MODEL,
+)
 
 
 class TopwordEmbeddings:
@@ -28,13 +30,14 @@ class TopwordEmbeddings:
     """
 
     def __init__(
-        self,
-        word_embedding_model: SentenceTransformer = SentenceTransformer(
-            PARAPHRASE_TRANSFORMER_MODEL),
-        cache_to_file: bool = False,
-        emb_filename: str = None,
-        emb_path: str = EMBEDDING_PATH,
-        create_new_file: bool = True
+        self,
+        word_embedding_model: SentenceTransformer = SentenceTransformer(
+            PARAPHRASE_TRANSFORMER_MODEL
+        ),
+        cache_to_file: bool = False,
+        emb_filename: str = None,
+        emb_path: str = EMBEDDING_PATH,
+        create_new_file: bool = True,
     ):
         """
         Initialize the TopwordEmbeddings object.
@@ -70,7 +73,8 @@ def _load_embedding_dict_from_disc(self):
         """
         try:
             self.embedding_dict = pickle.load(
-                open(f"{self.emb_path}{self.emb_filename}.pickle", "rb"))
+                open(f"{self.emb_path}{self.emb_filename}.pickle", "rb")
+            )
         except FileNotFoundError:
             self.embedding_dict = {}
 
@@ -79,13 +83,12 @@ def _save_embedding_dict_to_disc(self):
         Save the embedding dictionary to the disk.
         """
         with open(f"{self.emb_path}{self.emb_filename}.pickle", "wb") as handle:
-            pickle.dump(self.embedding_dict, handle,
-                        protocol=pickle.HIGHEST_PROTOCOL)
+            pickle.dump(self.embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
 
     def embed_topwords(
-            self,
-            topwords: np.ndarray,
-            n_topwords_to_use: int = 10,
+        self,
+        topwords: np.ndarray,
+        n_topwords_to_use: int = 10,
     ) -> np.ndarray:
         """
         Get the embeddings of the n_topwords topwords.
@@ -111,8 +114,11 @@ def embed_topwords(
             topwords = topwords.reshape(-1, 1)
 
         assert np.issubdtype(
-            topwords.dtype, np.str_), "topwords should only contain strings."
-        assert topwords.shape[1] >= n_topwords_to_use, "n_topwords_to_use should be less than or equal to the number of words in each topic."
+            topwords.dtype, np.str_
+        ), "topwords should only contain strings."
+        assert (
+            topwords.shape[1] >= n_topwords_to_use
+        ), "n_topwords_to_use should be less than or equal to the number of words in each topic."
 
         # Get the top n_topwords words
         topwords = topwords[:, :n_topwords_to_use]
diff --git a/stream_topic/metrics/coherence_metrics.py b/stream_topic/metrics/coherence_metrics.py
index 73b62f9797..cc7b5f0e44 100644
--- a/stream_topic/metrics/coherence_metrics.py
+++ b/stream_topic/metrics/coherence_metrics.py
@@ -1,15 +1,11 @@
 import re
-
 import gensim
-import nltk
 import numpy as np
 from nltk.corpus import stopwords
-from octis.dataset.dataset import Dataset
-from octis.evaluation_metrics.metrics import AbstractMetric
 from sentence_transformers import SentenceTransformer
 from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
-
-from ._helper_funcs import cos_sim_pw, embed_corpus, embed_topic, update_corpus_dic_list
+from .base import BaseMetric
+from ._helper_funcs import cos_sim_pw
 from .constants import (
     EMBEDDING_PATH,
     NLTK_STOPWORD_LANGUAGE,
@@ -24,7 +20,7 @@
 )
 
 
-class NPMI(AbstractMetric):
+class NPMI(BaseMetric):
     """
     A class for calculating Normalized Pointwise Mutual Information (NPMI) for topics.
 
@@ -76,6 +72,26 @@ def __init__(
         files = self.dataset.get_corpus()
         self.files = [" ".join(words) for words in files]
 
+    def get_info(self):
+        """
+        Get information about the metric.
+
+        Returns
+        -------
+        dict
+            Dictionary containing information about the metric, including
+            the metric name, the number of top words used, and a short
+            description.
+        """
+
+        info = {
+            "metric_name": "NPMI",
+            "n_words": self.n_words,
+            "description": "NPMI coherence",
+        }
+
+        return info
+
     def _create_vocab_preprocess(self, data, preprocess=5, process_data=False):
         """
         Creates and preprocesses a vocabulary from the given data.
@@ -287,7 +303,7 @@ def score_per_topic(self, topic_words, preprocess=5):
 
         return results
 
 
-class Embedding_Coherence(AbstractMetric):
+class Embedding_Coherence(BaseMetric):
     """
     A metric class to calculate the coherence of topics based on word embeddings. It computes
     the average cosine similarity between all top words in each topic.
diff --git a/stream_topic/metrics/diversity_metrics.py b/stream_topic/metrics/diversity_metrics.py
index cda2dbb318..6f56f8482a 100644
--- a/stream_topic/metrics/diversity_metrics.py
+++ b/stream_topic/metrics/diversity_metrics.py
@@ -2,17 +2,13 @@
 import nltk
 import numpy as np
 from nltk.corpus import stopwords
-from octis.evaluation_metrics.metrics import AbstractMetric
 from sentence_transformers import SentenceTransformer
 from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
 from sklearn.metrics.pairwise import cosine_similarity
-
+from .base import BaseMetric
 from ._helper_funcs import (
     cos_sim_pw,
-    embed_corpus,
     embed_stopwords,
-    embed_topic,
-    update_corpus_dic_list,
 )
 from .constants import (
     EMBEDDING_PATH,
@@ -29,7 +25,7 @@
 )
 
 
-class Embedding_Topic_Diversity(AbstractMetric):
+class Embedding_Topic_Diversity(BaseMetric):
     """
     A metric class to calculate the diversity of topics based on word embeddings. It computes
     the mean cosine similarity of the mean vectors of the top words of all topics, providing
@@ -204,7 +200,7 @@ def score_per_topic(self, topics, beta):
 
         return results
 
 
-class Expressivity(AbstractMetric):
+class Expressivity(BaseMetric):
     """
     A metric class to calculate the expressivity of topics by measuring the distance between
     the mean vector of the top words in a topic and the mean vector of the embeddings of
diff --git a/stream_topic/metrics/intruder_metrics.py b/stream_topic/metrics/intruder_metrics.py
index 2520550c33..245a07d4b5 100644
--- a/stream_topic/metrics/intruder_metrics.py
+++ b/stream_topic/metrics/intruder_metrics.py
@@ -1,9 +1,6 @@
 import numpy as np
-from octis.evaluation_metrics.metrics import AbstractMetric
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
-
-from ._helper_funcs import embed_corpus, embed_topic, update_corpus_dic_list
 from .base import BaseMetric
 from .constants import (
     EMBEDDING_PATH,
@@ -245,7 +242,7 @@ def score(self, topics, new_embeddings=True):
 
         return float(np.mean(list(self.score_per_topic(topics).values())))
 
 
-class INT(AbstractMetric):
+class INT(BaseMetric):
     """
     A metric class to calculate the Intruder Topic Metric (INT) for topics. This metric assesses the distinctiveness
     of topics by calculating the embedding intruder cosine similarity accuracy. It involves selecting intruder words
@@ -490,7 +487,7 @@ def score(self, topics, new_embeddings=True):
 
         return float(np.mean(list(self.score_per_topic(topics).values())))
 
 
-class ISH(AbstractMetric):
+class ISH(BaseMetric):
     """
     A metric class to calculate the Intruder Similarity to Mean (ISH) for topics. This metric evaluates the
     distinctiveness of topics by measuring the average cosine similarity between the mean of the
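Reviewer note (not part of the patch): the central change above is dropping the octis dependency and rebasing every metric on the package's own BaseMetric, imported via "from .base import BaseMetric". The patch does not show BaseMetric itself, so as a rough orientation the sketch below shows what a custom metric written against this interface might look like. It assumes BaseMetric simply expects the methods the migrated classes implement (get_info, score_per_topic, score); the class name and the toy metric are hypothetical.

    import numpy as np

    # Path assumed from the patch's "from .base import BaseMetric" imports.
    from stream_topic.metrics.base import BaseMetric


    class MeanTopwordLength(BaseMetric):
        """Toy metric: average character length of each topic's top words."""

        def __init__(self, n_words: int = 10):
            self.n_words = n_words

        def get_info(self):
            # Mirrors the get_info() dict added to NPMI in this patch.
            return {
                "metric_name": "MeanTopwordLength",
                "n_words": self.n_words,
                "description": "Average top-word length per topic (toy example).",
            }

        def score_per_topic(self, topics):
            # topics: list of lists of top words, e.g. from model.get_topics().
            return {
                ", ".join(words[: self.n_words]): float(
                    np.mean([len(w) for w in words[: self.n_words]])
                )
                for words in topics
            }

        def score(self, topics):
            # Same aggregation pattern the patched INT/ISH classes use:
            # the mean over the per-topic scores.
            return float(np.mean(list(self.score_per_topic(topics).values())))

The score()/score_per_topic() aggregation pattern is copied from the migrated classes (e.g. INT returns float(np.mean(list(self.score_per_topic(topics).values())))), so a metric written this way should stay drop-in compatible with callers that previously consumed the octis AbstractMetric interface.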