diff --git a/README.md b/README.md
index f5cc755d99..d4b06d585a 100644
--- a/README.md
+++ b/README.md
@@ -2,16 +2,16 @@
-[![PyPI](tbd)](tbd)
-![PyPI - Downloads](tbd)
-[![docs build](https://readthedocs.org/projects/stream/badge/?version=latest)](tbd)
-[![docs](https://img.shields.io/badge/docs-latest-blue)](tbd)
+[![PyPI](https://img.shields.io/pypi/v/stream_topic)](https://pypi.org/project/stream_topic)
+![PyPI - Downloads](https://img.shields.io/pypi/dm/stream_topic)
+[![docs build](https://readthedocs.org/projects/stream-topic/badge/?version=latest)](https://stream-topic.readthedocs.io/en/latest/?badge=latest)
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://stream-topic.readthedocs.io/en/latest/index.html)
[![open issues](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/AnFreTh/STREAM/issues)
-[📘Documentation](tbd) |
-[🛠️Installation](#installation) |
-[Models](#available-models) |
+[📘Documentation](https://stream-topic.readthedocs.io/en/latest/index.html) |
+[🛠️Installation](https://stream-topic.readthedocs.io/en/latest/installation.html) |
+[Models](https://stream-topic.readthedocs.io/en/latest/api/models/index.html) |
[🤔Report Issues](https://github.com/AnFreTh/STREAM/issues)
diff --git a/docs/conf.py b/docs/conf.py
index 2cc8feb950..ba90d793ba 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -4,6 +4,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import os
+
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import pathlib
@@ -11,8 +12,7 @@
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../"))
-sys.path.insert(1, os.path.dirname(
- os.path.abspath("../")) + os.sep + "stream")
+sys.path.insert(1, os.path.dirname(os.path.abspath("../")) + os.sep + "stream")
project = "stream_topic"
author = "Anton Frederik Thielmann"
@@ -37,7 +37,7 @@
"sphinx_copybutton",
"sphinx_togglebutton",
"nbsphinx",
- 'myst_nb',
+ "myst_nb",
"IPython.sphinxext.ipython_console_highlighting",
"IPython.sphinxext.ipython_directive",
# "myst_parser",
@@ -64,7 +64,6 @@
"gensim",
"nltk",
"langdetect",
- "octis",
"loguru",
"scipy",
"community",
@@ -75,9 +74,7 @@
"umap_learn",
"dash",
"optuna",
- "optuna-integration"
- ""
-
+ "optuna-integration" "",
]
# Add any paths that contain templates here, relative to this directory.
@@ -123,8 +120,8 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_book_theme"
-html_static_path = ['_static']
-html_css_files = ['css/stream_theme.css']
+html_static_path = ["_static"]
+html_css_files = ["css/stream_theme.css"]
# html_css_files = ['custom.css']
# html_js_files = ['custom.js']
@@ -161,12 +158,12 @@
# -- Options for myst ----------------------------------------------
# uncomment line below to avoid running notebooks during development
-nb_execution_mode = 'off'
+nb_execution_mode = "off"
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
# files that will not be executed.
-myst_enable_extensions = ['dollarmath']
+myst_enable_extensions = ["dollarmath"]
# raise exceptions on execution so CI can catch errors
nb_execution_allow_errors = False
@@ -178,4 +175,6 @@
# other MyST extensions
]
-mathjax_path = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
+mathjax_path = (
+ "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
+)
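
For context on the `autodoc_mock_imports` edits above: Sphinx substitutes a mock for each listed package at import time, so the docs build does not need the heavy runtime dependencies installed (and `octis` can be dropped entirely). A minimal sketch of the relevant conf.py fragment after this change, with the full mock list abbreviated:

```python
# docs/conf.py (excerpt) -- mock heavy runtime dependencies so autodoc
# can import stream_topic on Read the Docs without installing them.
autodoc_mock_imports = [
    "gensim",
    "nltk",
    "langdetect",  # ... remaining entries as in the diff; "octis" removed
    "optuna",
    "optuna-integration",
]

# Notebooks are included via myst-nb but not executed during the build.
nb_execution_mode = "off"
nb_execution_timeout = 100  # per-cell timeout in seconds, if execution is on
myst_enable_extensions = ["dollarmath"]
```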
diff --git a/examples/KmeansTM.ipynb b/examples/KmeansTM.ipynb
deleted file mode 100644
index ebd0e293e4..0000000000
--- a/examples/KmeansTM.ipynb
+++ /dev/null
@@ -1,302 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Import and load dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "If you have not installed the package and work inside the repository, use the following two lines of code to make the example word:\n",
- "\n",
- "import sys\n",
- "sys.path.append(\"..\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\torch\\_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
- " return self.fget.__get__(instance, owner)()\n",
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\dash\\_jupyter.py:31: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
- " _dash_comm = Comm(target_name=\"dash\")\n"
- ]
- }
- ],
- "source": [
- "from stream.utils import TMDataset\n",
- "from stream.models import KmeansTM\n",
- "\n",
- "dataset = TMDataset()\n",
- "dataset.fetch_dataset(\"20NewsGroup\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'en'"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dataset.language"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Preprocessing documents: 0%| | 0/2500 [00:00, ?it/s]"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Preprocessing documents: 100%|██████████| 2500/2500 [00:39<00:00, 64.01it/s]\n"
- ]
- }
- ],
- "source": [
- "dataset.preprocess(model_type=\"KmeansTM\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['I', 'was', 'wondering', 'if', 'anyone', 'out', 'there', 'could', 'enlighten', 'me', 'on', 'this', 'car', 'I', 'saw', 'the', 'other', 'day.', 'It', 'was', 'a', '2door', 'sports', 'car,', 'looked', 'to', 'be', 'from', 'the', 'late', '60s/', 'early', '70s.', 'It', 'was', 'called', 'a', 'Bricklin.', 'The', 'doors', 'were', 'really', 'small.', 'In', 'addition,', 'the', 'front', 'bumper', 'was', 'separate', 'from', 'the', 'rest', 'of', 'the', 'body.', 'This', 'is', 'all', 'I', 'know.', 'If', 'anyone', 'can', 'tellme', 'a', 'model', 'name,', 'engine', 'specs,', 'years', 'of', 'production,', 'where', 'this', 'car', 'is', 'made,', 'history,', 'or', 'whatever', 'info', 'you', 'have', 'on', 'this', 'funky', 'looking', 'car,', 'please', 'email.']\n"
- ]
- }
- ],
- "source": [
- "print(dataset.dataframe.iloc[0][\"tokens\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "wondering anyone could enlighten car saw day door car looked late early called bricklin door really small addition front bumper separate rest body know anyone tellme model name engine production car made history whatever info funky looking car please email\n"
- ]
- }
- ],
- "source": [
- "print(dataset.dataframe.iloc[0][\"text\"])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Train the model\n",
- "If embeddings for the model have been created before, they will not be created again for faster computation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m2024-06-23 11:00:34.330\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m262\u001b[0m - \u001b[1m--- Training KmeansTM topic model ---\u001b[0m\n",
- "\u001b[32m2024-06-23 11:00:34.338\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_prepare_embeddings\u001b[0m:\u001b[36m161\u001b[0m - \u001b[1m--- Loading pre-computed paraphrase-MiniLM-L3-v2 embeddings ---\u001b[0m\n",
- "\u001b[32m2024-06-23 11:00:34.342\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_dim_reduction\u001b[0m:\u001b[36m194\u001b[0m - \u001b[1m--- Reducing dimensions ---\u001b[0m\n",
- "\u001b[32m2024-06-23 11:00:46.663\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_clustering\u001b[0m:\u001b[36m214\u001b[0m - \u001b[1m--- Creating document cluster ---\u001b[0m\n",
- "\u001b[32m2024-06-23 11:00:47.279\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m295\u001b[0m - \u001b[1m--- Training completed successfully. ---\u001b[0m\n"
- ]
- }
- ],
- "source": [
- "model = KmeansTM() \n",
- "model.fit(dataset) \n",
- "topics = model.get_topics()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[['printer', 'font', 'truetype', 'laser', 'postscript', 'window', 'color', 'print', 'image', 'atm'], ['gun', 'weapon', 'fbi', 'koresh', 'crime', 'firearm', 'child', 'atf', 'batf', 'kid'], ['max', 'b8f', 'a86', '145', '0t', 'pl', '1d9', '1t', 'giz', '2di'], ['game', 'team', 'player', 'play', 'hockey', 'goal', 'season', 'playoff', 'period', 'leaf'], ['health', 'patient', 'tobacco', 'doctor', 'food', 'medical', 'disease', 'msg', 'treatment', 'infection'], ['car', 'bike', 'engine', 'mile', 'speed', 'mph', 'mustang', 'wheel', 'tire', 'diesel'], ['moral', 'objective', 'morality', 'homosexual', 'sex', 'gay', 'men', 'murder', 'sexual', 'immoral'], ['circuit', 'mouse', 'input', 'digital', 'output', 'audio', 'signal', 'monitor', 'screen', 'amp'], ['space', 'nasa', 'shuttle', 'mission', 'orbit', 'satellite', 'launch', 'earth', 'astronaut', 'solar'], ['file', 'window', 'widget', 'image', 'program', 'motif', 'available', 'server', 'set', 'version'], ['key', 'encryption', 'clipper', 'chip', 'government', 'phone', 'security', 'escrow', 'privacy', 'secure'], ['government', 'clinton', 'constitution', 'president', 'right', 'law', 'militia', 'tax', 'state', 'billion'], ['gainey', 'post', 'bob', 'david', 'smiley', 'sternlight', 'beer', 'think', 'article', 'gilmour'], ['contact', 'machine', 'type', 'version', 'ftp', 'comment', 'address', 'anonymous', 'email', 'number'], ['god', 'jesus', 'christian', 'faith', 'bible', 'truth', 'church', 'sin', 'argument', 'christ'], ['armenian', 'turkish', 'muslim', 'israel', 'jew', 'genocide', 'arab', 'israeli', 'war', 'said'], ['drive', 'scsi', 'card', 'driver', 'disk', 'tape', 'controller', 'memory', 'bus', 'window'], ['team', 'player', 'brave', 'game', 'year', 'baseball', 'hit', 'morris', 'season', 'better'], ['dog', 'neighbor', 'detector', 'stove', 'radar', 'neighborhood', 'phoenix', 'bike', 'house', 'beast'], ['shipping', 'excellent', 'price', 'tnde', 'offer', 'uccxkvb', 'geoffrey', 'cover', 'missing', 'good']]\n"
- ]
- }
- ],
- "source": [
- "print(topics)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Evluate your model. \n",
- "Use all metrics available either in octis or the ExpandedTM metrics, ISIM, INT, Expressivity, Embedding_Coherence, Embedding_Topic_Diversity and classical NPMI"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "from stream.metrics import NPMI\n",
- "metric = NPMI(dataset)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "-0.1961\n"
- ]
- }
- ],
- "source": [
- "topics = model.get_topics()\n",
- "score = metric.score(topics)\n",
- "print(score)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Visualize your fit model\n",
- "Use a port that is not already in use. default is 8050"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "from stream.visuals import visualize_topic_model, visualize_topics\n",
- "visualize_topic_model(model, dataset=dataset, port=8051)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "visualize_topics(model, port=8052)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "STREAM_venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.13"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
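
The deleted notebook above walked through the full KmeansTM workflow. For reference, a minimal script version of that workflow, using only the calls shown in the notebook cells (note the old notebook imported from `stream`; the package itself is now named `stream_topic`, so the import path here reflects the deleted code, not necessarily the current API):

```python
from stream.utils import TMDataset
from stream.models import KmeansTM
from stream.metrics import NPMI

# Fetch the bundled 20NewsGroup sample and preprocess it for KmeansTM
dataset = TMDataset()
dataset.fetch_dataset("20NewsGroup")
dataset.preprocess(model_type="KmeansTM")

# Fit the model; previously computed embeddings are reused automatically
model = KmeansTM()
model.fit(dataset)
topics = model.get_topics()

# Evaluate topic coherence with NPMI
metric = NPMI(dataset)
print(metric.score(topics))

# Optional: interactive dashboards (pick free ports; the default is 8050)
from stream.visuals import visualize_topic_model, visualize_topics
visualize_topic_model(model, dataset=dataset, port=8051)
visualize_topics(model, port=8052)
```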
diff --git a/examples/downstream_task.ipynb b/examples/downstream_task.ipynb
deleted file mode 100644
index ac52f877aa..0000000000
--- a/examples/downstream_task.ipynb
+++ /dev/null
@@ -1,152 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Import and load dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from stream.data_utils import TMDataset\n",
- "from stream.models import CEDC\n",
- "\n",
- "dataset = TMDataset()\n",
- "dataset.fetch_dataset(\"Spotify_random\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Train the model\n",
- "If embeddings for the model have been created before, they will not be created again for faster computation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "model = CEDC(num_topics=5) # Create model\n",
- "model_output = model.train_model(dataset) "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Visualize topics"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from stream.visuals import visualize_topic_model\n",
- "visualize_topic_model(model, port=8053)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Fit Downstream Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pytorch_lightning import Trainer\n",
- "from stream.NAM import DownstreamModel\n",
- "\n",
- "# Instantiate the DownstreamModel\n",
- "downstream_model = DownstreamModel(\n",
- " trained_topic_model=model,\n",
- " target_column='popularity', # Target variable\n",
- " task='regression', # or 'classification'\n",
- " dataset=dataset, \n",
- " batch_size=512,\n",
- " lr=0.0005\n",
- ")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Train the downstream model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Use PyTorch Lightning's Trainer to train and validate the model\n",
- "trainer = Trainer(max_epochs=50)\n",
- "trainer.fit(downstream_model)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "trainer.validate()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Visualize Feature and Topic contributions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from stream.visuals import plot_downstream_model\n",
- "\n",
- "plot_downstream_model(downstream_model)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "ExpandedTM_venv",
- "language": "python",
- "name": "expandedtm_venv"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.18"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
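
Likewise, a compact sketch of the downstream-task workflow the deleted notebook demonstrated, restricted to the calls it actually used:

```python
from pytorch_lightning import Trainer
from stream.data_utils import TMDataset
from stream.models import CEDC
from stream.NAM import DownstreamModel

# Fit a CEDC topic model on the Spotify sample
dataset = TMDataset()
dataset.fetch_dataset("Spotify_random")
model = CEDC(num_topics=5)
model_output = model.train_model(dataset)

# Wrap the trained topic model in a downstream model
downstream_model = DownstreamModel(
    trained_topic_model=model,
    target_column="popularity",  # target variable in the Spotify data
    task="regression",           # or "classification"
    dataset=dataset,
    batch_size=512,
    lr=0.0005,
)

# Train and validate with PyTorch Lightning
trainer = Trainer(max_epochs=50)
trainer.fit(downstream_model)
trainer.validate()

# Inspect feature and topic contributions
from stream.visuals import plot_downstream_model
plot_downstream_model(downstream_model)
```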
diff --git a/examples/load_own_data.ipynb b/examples/load_own_data.ipynb
deleted file mode 100644
index c3bdfa836b..0000000000
--- a/examples/load_own_data.ipynb
+++ /dev/null
@@ -1,159 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Since STREAM is using the preprocessing from the octis core package, you must download some spacy specific utils.\n",
- "\n",
- "python -m spacy download en_core_web_sm before."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\torch\\_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
- " return self.fget.__get__(instance, owner)()\n",
- "c:\\Users\\anton\\anaconda3\\envs\\STREAM_venv\\lib\\site-packages\\dash\\_jupyter.py:31: DeprecationWarning: The `ipykernel.comm.Comm` class has been deprecated. Please use the `comm` module instead.For creating comms, use the function `from comm import create_comm`.\n",
- " _dash_comm = Comm(target_name=\"dash\")\n"
- ]
- }
- ],
- "source": [
- "import sys\n",
- "sys.path.append(\"..\")\n",
- "from stream.utils import TMDataset\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "\n",
- "\n",
- "# Simulating some example data\n",
- "np.random.seed(0) # For reproducibility\n",
- "\n",
- "# Generate 1000 random strings of lengths between 1 and 5, containing letters 'A' to 'Z'\n",
- "random_documents = [''.join(np.random.choice(list('ABCDEFGHIJKLMNOPQRSTUVWXYZ'), \n",
- " np.random.randint(1, 6))) for _ in range(1000)]\n",
- "\n",
- "# Generate 1000 random labels from 1 to 4 as strings\n",
- "random_labels = np.random.choice(['1', '2', '3', '4'], 1000)\n",
- "\n",
- "# Create DataFrame\n",
- "my_data = pd.DataFrame({\"Documents\": random_documents, \"Labels\": random_labels})\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "pass the dataframe to the create_load_save_dataset function and specify the used columns. use label_column=None if no labels are available.\n",
- "The dataset is preprocessed and saved and directly returned. If you want to use your dataset later, you can simply run dataset.fetch_dataset(your_dataset_path)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Preprocessing documents: 100%|██████████| 1000/1000 [00:13<00:00, 75.92it/s]\n"
- ]
- }
- ],
- "source": [
- "dataset = TMDataset()\n",
- "dataset = dataset.create_load_save_dataset(\n",
- " data=my_data, \n",
- " dataset_name=\"my_dataset_name\",\n",
- " save_dir=\"my_dataset_save_directory\",\n",
- " doc_column=\"Documents\",\n",
- " label_column=\"Labels\"\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# To fecth the dataset, simply run"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "dataset = TMDataset()\n",
- "dataset.fetch_dataset(name=\"my_dataset_name\", dataset_path=\"my_dataset_save_directory\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[32m2024-06-23 13:19:19.418\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m262\u001b[0m - \u001b[1m--- Training KmeansTM topic model ---\u001b[0m\n",
- "\u001b[32m2024-06-23 13:19:19.419\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_prepare_embeddings\u001b[0m:\u001b[36m171\u001b[0m - \u001b[1m--- Creating paraphrase-MiniLM-L3-v2 document embeddings ---\u001b[0m\n",
- "100%|██████████| 1000/1000 [00:06<00:00, 155.09it/s]\n",
- "\u001b[32m2024-06-23 13:19:25.980\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_dim_reduction\u001b[0m:\u001b[36m194\u001b[0m - \u001b[1m--- Reducing dimensions ---\u001b[0m\n",
- "\u001b[32m2024-06-23 13:19:35.562\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36m_clustering\u001b[0m:\u001b[36m214\u001b[0m - \u001b[1m--- Creating document cluster ---\u001b[0m\n",
- "\u001b[32m2024-06-23 13:19:35.810\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mstream.models.KmeansTM\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m295\u001b[0m - \u001b[1m--- Training completed successfully. ---\u001b[0m\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[['vb', 'ld', 'rd', 'ncb', 'hmaul', 'vjebr', 'vhk', 'vg', 'vfhy', 'vf'], ['ps', 'lot', 'mu', 'nose', 'psce', 'warm', 'fl', 'byvet', 'jp', 'id'], ['eq', 'et', 'eqzdh', 'eqtoi', 'ev', 'ejiu', 'eyvh', 'ezako', 'egum', 'egarx'], ['hv', 'hws', 'hxohw', 'hu', 'htnsy', 'htafl', 'hxr', 'hyyp', 'hb', 'hbbh'], ['pe', 'ki', 'pekr', 'krlri', 'kkacb', 'kmgd', 'kmr', 'kn', 'knsb', 'knvb'], ['jd', 'gj', 'ullfq', 'pfjw', 'wrgh', 'ws', 'ortdj', 'jvh', 'wwm', 'jjma'], ['zd', 'zzg', 'zmism', 'zya', 'ixtp', 'sqzh', 'wuchn', 'zanmg', 'zarh', 'zfnv'], ['ov', 'oa', 'oinjo', 'ocro', 'ooson', 'oou', 'oat', 'oqub', 'oygny', 'oy'], ['tkg', 'tlg', 'oflt', 'tgo', 'tm', 'jp', 'jon', 'jo', 'jmmt', 'jmafw'], ['xsk', 'xooqd', 'xqfca', 'xqhlk', 'xo', 'xtz', 'xn', 'xyy', 'xly', 'xzos'], ['si', 'ysd', 'ybf', 'yrtt', 'ys', 'yp', 'yn', 'yv', 'yydsp', 'yynx'], ['dy', 'dp', 'dsrbs', 'dmeb', 'dmeui', 'dmgpe', 'dgcf', 'dfmsa', 'ds', 'dn'], ['sb', 'en', 'ab', 'rf', 'bveff', 'bue', 'bma', 'ukkur', 'bh', 'bee'], ['qci', 'qcp', 'qcl', 'qcw', 'qcfuc', 'qcbwn', 'qc', 'qco', 'hrf', 'jqgvd'], ['vz', 'vmzk', 'gg', 'woz', 'kwza', 'ouvd', 'oszm', 'osza', 'wvvz', 'osao'], ['lf', 'lml', 'lrjaw', 'lxg', 'lxc', 'lx', 'lwaen', 'lvmpx', 'yfcml', 'lqd'], ['ib', 'og', 'uc', 'da', 'ua', 'dh', 'ih', 'th', 'ng', 'tarcs'], ['fb', 'mrqp', 'wqp', 'fsyq', 'whmqq', 'ecq', 'qifpf', 'fvo', 'fvq', 'wnxzq'], ['rk', 'hyzym', 'rgpcv', 'rhb', 'rhc', 'rhk', 'ftxyt', 'rkh', 'rkm', 'rpq'], ['mw', 'qld', 'qmfu', 'qksxu', 'km', 'qlbz', 'qysf', 'qlv', 'qm', 'qni']]\n"
- ]
- }
- ],
- "source": [
- "from stream.models import KmeansTM\n",
- "# -> specify a existing folder path where to save the embeddings (or where to load the pre embedded dataset)\n",
- "model = KmeansTM(embeddings_folder_path=\"../my_embedding_folder\") \n",
- "# -> set the following arguments for num_topic optimization: KmeansTM(optim = True, optim_range = [5, 25])\n",
- "model.fit(dataset) \n",
- "topics = model.get_topics()\n",
- "print(topics)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "stream_venv",
- "language": "python",
- "name": "stream_venv"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.13"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
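
And the own-data workflow from the third deleted notebook, condensed into a script (the random-string documents from the notebook are replaced with two placeholder strings for brevity):

```python
import pandas as pd
from stream.utils import TMDataset
from stream.models import KmeansTM

# Any dataframe with a document column (and optionally a label column) works
my_data = pd.DataFrame(
    {
        "Documents": ["first document text", "second document text"],
        "Labels": ["1", "2"],
    }
)

# Preprocess, save, and return the dataset in one call;
# pass label_column=None if no labels are available
dataset = TMDataset()
dataset = dataset.create_load_save_dataset(
    data=my_data,
    dataset_name="my_dataset_name",
    save_dir="my_dataset_save_directory",
    doc_column="Documents",
    label_column="Labels",
)

# Later sessions can reload the saved dataset directly
dataset = TMDataset()
dataset.fetch_dataset(name="my_dataset_name", dataset_path="my_dataset_save_directory")

# embeddings_folder_path caches document embeddings between runs;
# KmeansTM(optim=True, optim_range=[5, 25]) would tune the topic count
model = KmeansTM(embeddings_folder_path="../my_embedding_folder")
model.fit(dataset)
print(model.get_topics())
```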
diff --git a/stream_topic/__version__.py b/stream_topic/__version__.py
index 7c5d2fe9fe..10b95664f1 100644
--- a/stream_topic/__version__.py
+++ b/stream_topic/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""
# The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.0"
+__version__ = "0.1.1"
diff --git a/stream_topic/metrics/TopwordEmbeddings.py b/stream_topic/metrics/TopwordEmbeddings.py
index 5ce37bcbc5..7bbaf3d210 100644
--- a/stream_topic/metrics/TopwordEmbeddings.py
+++ b/stream_topic/metrics/TopwordEmbeddings.py
@@ -1,12 +1,14 @@
-
import os
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
-from .constants import (EMBEDDING_PATH, PARAPHRASE_TRANSFORMER_MODEL,
- SENTENCE_TRANSFORMER_MODEL)
+from .constants import (
+ EMBEDDING_PATH,
+ PARAPHRASE_TRANSFORMER_MODEL,
+ SENTENCE_TRANSFORMER_MODEL,
+)
class TopwordEmbeddings:
@@ -28,13 +30,14 @@ class TopwordEmbeddings:
"""
def __init__(
- self,
- word_embedding_model: SentenceTransformer = SentenceTransformer(
- PARAPHRASE_TRANSFORMER_MODEL),
- cache_to_file: bool = False,
- emb_filename: str = None,
- emb_path: str = EMBEDDING_PATH,
- create_new_file: bool = True
+ self,
+ word_embedding_model: SentenceTransformer = SentenceTransformer(
+ PARAPHRASE_TRANSFORMER_MODEL
+ ),
+ cache_to_file: bool = False,
+ emb_filename: str = None,
+ emb_path: str = EMBEDDING_PATH,
+ create_new_file: bool = True,
):
"""
Initialize the TopwordEmbeddings object.
@@ -70,7 +73,8 @@ def _load_embedding_dict_from_disc(self):
"""
try:
self.embedding_dict = pickle.load(
- open(f"{self.emb_path}{self.emb_filename}.pickle", "rb"))
+ open(f"{self.emb_path}{self.emb_filename}.pickle", "rb")
+ )
except FileNotFoundError:
self.embedding_dict = {}
@@ -79,13 +83,12 @@ def _save_embedding_dict_to_disc(self):
Save the embedding dictionary to the disk.
"""
with open(f"{self.emb_path}{self.emb_filename}.pickle", "wb") as handle:
- pickle.dump(self.embedding_dict, handle,
- protocol=pickle.HIGHEST_PROTOCOL)
+ pickle.dump(self.embedding_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
def embed_topwords(
- self,
- topwords: np.ndarray,
- n_topwords_to_use: int = 10,
+ self,
+ topwords: np.ndarray,
+ n_topwords_to_use: int = 10,
) -> np.ndarray:
"""
Get the embeddings of the n_topwords topwords.
@@ -111,8 +114,11 @@ def embed_topwords(
topwords = topwords.reshape(-1, 1)
assert np.issubdtype(
- topwords.dtype, np.str_), "topwords should only contain strings."
- assert topwords.shape[1] >= n_topwords_to_use, "n_topwords_to_use should be less than or equal to the number of words in each topic."
+ topwords.dtype, np.str_
+ ), "topwords should only contain strings."
+ assert (
+ topwords.shape[1] >= n_topwords_to_use
+ ), "n_topwords_to_use should be less than or equal to the number of words in each topic."
# Get the top n_topwords words
topwords = topwords[:, :n_topwords_to_use]
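
A short usage sketch for the reformatted class, assuming the import path matches the file location. The two assertions above require a string array with at least `n_topwords_to_use` words per topic:

```python
import numpy as np
from stream_topic.metrics.TopwordEmbeddings import TopwordEmbeddings

# One row per topic; columns are that topic's top words
topwords = np.array(
    [
        ["space", "nasa", "orbit", "satellite", "launch"],
        ["game", "team", "player", "season", "hockey"],
    ]
)

embedder = TopwordEmbeddings(cache_to_file=False)  # default paraphrase model
embeddings = embedder.embed_topwords(topwords, n_topwords_to_use=5)
```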
diff --git a/stream_topic/metrics/coherence_metrics.py b/stream_topic/metrics/coherence_metrics.py
index 73b62f9797..cc7b5f0e44 100644
--- a/stream_topic/metrics/coherence_metrics.py
+++ b/stream_topic/metrics/coherence_metrics.py
@@ -1,15 +1,11 @@
import re
-
import gensim
-import nltk
import numpy as np
from nltk.corpus import stopwords
-from octis.dataset.dataset import Dataset
-from octis.evaluation_metrics.metrics import AbstractMetric
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
-
-from ._helper_funcs import cos_sim_pw, embed_corpus, embed_topic, update_corpus_dic_list
+from .base import BaseMetric
+from ._helper_funcs import cos_sim_pw
from .constants import (
EMBEDDING_PATH,
NLTK_STOPWORD_LANGUAGE,
@@ -24,7 +20,7 @@
)
-class NPMI(AbstractMetric):
+class NPMI(BaseMetric):
"""
A class for calculating Normalized Pointwise Mutual Information (NPMI) for topics.
@@ -76,6 +72,26 @@ def __init__(
files = self.dataset.get_corpus()
self.files = [" ".join(words) for words in files]
+ def get_info(self):
+ """
+ Get information about the metric.
+
+ Returns
+ -------
+ dict
+ Dictionary containing model information including metric name,
+ number of top words, number of intruders, embedding model name,
+ metric range and metric description.
+ """
+
+ info = {
+ "metric_name": "NPMI",
+ "n_words": self.n_words,
+ "description": "NPMI coherence",
+ }
+
+ return info
+
def _create_vocab_preprocess(self, data, preprocess=5, process_data=False):
"""
Creates and preprocesses a vocabulary from the given data.
@@ -287,7 +303,7 @@ def score_per_topic(self, topic_words, preprocess=5):
return results
-class Embedding_Coherence(AbstractMetric):
+class Embedding_Coherence(BaseMetric):
"""
A metric class to calculate the coherence of topics based on word embeddings. It computes
the average cosine similarity between all top words in each topic.
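
The practical effect of this refactor is that the metrics no longer depend on octis: they subclass the package's own `BaseMetric`, and NPMI gains a `get_info()` method describing itself. A hedged sketch, reusing the call pattern from the deleted notebook and assuming `NPMI` is re-exported from `stream_topic.metrics` (here `dataset` and `topics` are the objects from the earlier KmeansTM sketch):

```python
from stream_topic.metrics import NPMI

metric = NPMI(dataset)  # dataset: a preprocessed TMDataset
print(metric.get_info())
# -> {"metric_name": "NPMI", "n_words": ..., "description": "NPMI coherence"}
score = metric.score(topics)
```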
diff --git a/stream_topic/metrics/diversity_metrics.py b/stream_topic/metrics/diversity_metrics.py
index cda2dbb318..6f56f8482a 100644
--- a/stream_topic/metrics/diversity_metrics.py
+++ b/stream_topic/metrics/diversity_metrics.py
@@ -2,17 +2,13 @@
import nltk
import numpy as np
from nltk.corpus import stopwords
-from octis.evaluation_metrics.metrics import AbstractMetric
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
from sklearn.metrics.pairwise import cosine_similarity
-
+from .base import BaseMetric
from ._helper_funcs import (
cos_sim_pw,
- embed_corpus,
embed_stopwords,
- embed_topic,
- update_corpus_dic_list,
)
from .constants import (
EMBEDDING_PATH,
@@ -29,7 +25,7 @@
)
-class Embedding_Topic_Diversity(AbstractMetric):
+class Embedding_Topic_Diversity(BaseMetric):
"""
A metric class to calculate the diversity of topics based on word embeddings. It computes
the mean cosine similarity of the mean vectors of the top words of all topics, providing
@@ -204,7 +200,7 @@ def score_per_topic(self, topics, beta):
return results
-class Expressivity(AbstractMetric):
+class Expressivity(BaseMetric):
"""
A metric class to calculate the expressivity of topics by measuring the distance between
the mean vector of the top words in a topic and the mean vector of the embeddings of
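
The diversity metrics get the same base-class swap. As a standalone illustration of the idea the `Embedding_Topic_Diversity` docstring describes (not the class's actual internals), the computation reduces to the mean pairwise cosine similarity between each topic's mean top-word vector:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Dummy data: 5 topics, 10 top words each, 384-dim embeddings
topic_embeddings = [np.random.rand(10, 384) for _ in range(5)]
topic_means = np.vstack([emb.mean(axis=0) for emb in topic_embeddings])

sims = cosine_similarity(topic_means)
upper = np.triu_indices_from(sims, k=1)  # each topic pair once
diversity_score = sims[upper].mean()     # lower similarity -> more diverse topics
print(diversity_score)
```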
diff --git a/stream_topic/metrics/intruder_metrics.py b/stream_topic/metrics/intruder_metrics.py
index 2520550c33..245a07d4b5 100644
--- a/stream_topic/metrics/intruder_metrics.py
+++ b/stream_topic/metrics/intruder_metrics.py
@@ -1,9 +1,6 @@
import numpy as np
-from octis.evaluation_metrics.metrics import AbstractMetric
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
-
-from ._helper_funcs import embed_corpus, embed_topic, update_corpus_dic_list
from .base import BaseMetric
from .constants import (
EMBEDDING_PATH,
@@ -245,7 +242,7 @@ def score(self, topics, new_embeddings=True):
return float(np.mean(list(self.score_per_topic(topics).values())))
-class INT(AbstractMetric):
+class INT(BaseMetric):
"""
A metric class to calculate the Intruder Topic Metric (INT) for topics. This metric assesses the distinctiveness
of topics by calculating the embedding intruder cosine similarity accuracy. It involves selecting intruder words
@@ -490,7 +487,7 @@ def score(self, topics, new_embeddings=True):
return float(np.mean(list(self.score_per_topic(topics).values())))
-class ISH(AbstractMetric):
+class ISH(BaseMetric):
"""
A metric class to calculate the Intruder Similarity to Mean (ISH) for topics. This metric evaluates
the distinctiveness of topics by measuring the average cosine similarity between the mean of the