change default prepro steps
AnFreTh committed Aug 7, 2024
1 parent f544961 commit 76fd691
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions stream_topic/utils/dataset.py
@@ -18,7 +18,7 @@

class TMDataset(Dataset):
"""
Topic Modeling Dataset containing methods to fetch and preprocess text data.
Topic Modeling Dataset containing methods to fetch and preprocess text data.
Parameters
----------
@@ -83,7 +83,8 @@ def __init__(self, name=None, language="en"):
self.available_datasets = self.get_dataset_list()
if name is not None and name not in self.available_datasets:
logger.error(
f"Dataset {name} not found. Available datasets: {self.available_datasets}")
f"Dataset {name} not found. Available datasets: {self.available_datasets}"
)
raise ValueError(
f"Dataset {name} not found. Available datasets: {self.available_datasets}"
)
@@ -107,8 +108,7 @@ def get_dataset_list(self):
list of str
List of available datasets.
"""
package_path = os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))
package_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
dataset_path = os.path.join(package_path, "preprocessed_datasets")
datasets = os.listdir(dataset_path)
return datasets
@@ -126,7 +126,7 @@ def default_preprocessing_steps(self):
"remove_special_chars": True,
"remove_accents": False,
"custom_stopwords": set(),
"detokenize": True,
"detokenize": False,
}
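
The hunk above carries the substantive change of this commit: the default detokenize flag flips from True to False, so the corpus stays tokenized unless a caller asks otherwise. A minimal usage sketch, assuming the keys of this dict can be passed straight to preprocess() as keyword overrides; the dataset name is a placeholder:

    # Hedged sketch based only on the signatures visible in this diff;
    # "my_dataset" is a placeholder for an entry in dataset.available_datasets.
    from stream_topic.utils.dataset import TMDataset

    dataset = TMDataset()
    dataset.fetch_dataset("my_dataset")

    # The keyword names mirror the keys of default_preprocessing_steps;
    # detokenize now defaults to False after this commit.
    dataset.preprocess(
        remove_special_chars=True,
        remove_accents=False,
        detokenize=False,
    )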

def fetch_dataset(self, name: str, dataset_path=None):
@@ -142,16 +142,19 @@ def fetch_dataset(self, name: str, dataset_path=None):
"""
if name not in self.available_datasets:
logger.error(
f"Dataset {name} not found. Available datasets: {self.available_datasets}")
f"Dataset {name} not found. Available datasets: {self.available_datasets}"
)
raise ValueError(
f"Dataset {name} not found. Available datasets: {self.available_datasets}"
)

if self.name is not None:
logger.info(
f'Dataset name already provided while instantiating the class: {self.name}')
f"Dataset name already provided while instantiating the class: {self.name}"
)
logger.info(
f'Overwriting the dataset name with the provided name in fetch_dataset: {name}')
f"Overwriting the dataset name with the provided name in fetch_dataset: {name}"
)
self.name = name
logger.info(f"Fetching dataset: {name}")
else:
@@ -180,8 +183,7 @@ def _load_data_to_dataframe(self):
"labels": self.get_labels(),
}
)
self.dataframe["text"] = [" ".join(words)
for words in self.dataframe["tokens"]]
self.dataframe["text"] = [" ".join(words) for words in self.dataframe["tokens"]]
self.texts = self.dataframe["text"].tolist()
self.labels = self.dataframe["labels"].tolist()

@@ -201,8 +203,7 @@ def get_package_dataset_path(self, name):
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
my_package_dir = os.path.dirname(script_dir)
dataset_path = os.path.join(
my_package_dir, "preprocessed_datasets", name)
dataset_path = os.path.join(my_package_dir, "preprocessed_datasets", name)
return dataset_path

def has_embeddings(self, embedding_model_name, path=None, file_name=None):
@@ -337,8 +338,7 @@ def get_package_embeddings_path(self, name):
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
my_package_dir = os.path.dirname(script_dir)
dataset_path = os.path.join(
my_package_dir, "pre_embedded_datasets", name)
dataset_path = os.path.join(my_package_dir, "pre_embedded_datasets", name)
return dataset_path

def create_load_save_dataset(
@@ -375,21 +375,18 @@ def create_load_save_dataset(
"""
if isinstance(data, pd.DataFrame):
if doc_column is None:
raise ValueError(
"doc_column must be specified for DataFrame input")
raise ValueError("doc_column must be specified for DataFrame input")
documents = [
self.clean_text(str(row[doc_column])) for _, row in data.iterrows()
]
labels = (
data[label_column].tolist() if label_column else [
None] * len(documents)
data[label_column].tolist() if label_column else [None] * len(documents)
)
elif isinstance(data, list):
documents = [self.clean_text(doc) for doc in data]
labels = [None] * len(documents)
else:
raise TypeError(
"data must be a pandas DataFrame or a list of documents")
raise TypeError("data must be a pandas DataFrame or a list of documents")

# Initialize preprocessor with kwargs
preprocessor = TextPreprocessor(**kwargs)
@@ -406,7 +403,7 @@ def create_load_save_dataset(

# Save the dataset to Parquet format
if not os.path.exists(save_dir):
logger.info(f'Dataset save directory does not exist: {save_dir}')
logger.info(f"Dataset save directory does not exist: {save_dir}")
logger.info(f"Creating directory: {save_dir}")
os.makedirs(save_dir)

@@ -431,7 +428,8 @@ def create_load_save_dataset(

self.available_datasets.append(dataset_name)
logger.info(
f'Dataset name appended to available datasets list: {self.available_datasets}')
f"Dataset name appended to available datasets list: {self.available_datasets}"
)
# return preprocessor
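
A hedged sketch of create_load_save_dataset; its full parameter list is collapsed in this diff, so treating data, doc_column, label_column, save_dir and dataset_name as keyword arguments is an assumption based on the names used in the method body:

    import pandas as pd

    df = pd.DataFrame(
        {
            "abstract": ["first document text", "second document text"],
            "topic": ["sports", "politics"],
        }
    )

    dataset = TMDataset()
    dataset.create_load_save_dataset(
        data=df,                           # DataFrame or list of raw documents
        dataset_name="my_custom_dataset",  # appended to available_datasets
        save_dir="./my_datasets",          # created if it does not exist
        doc_column="abstract",             # required for DataFrame input
        label_column="topic",              # optional; None labels otherwise
    )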

def preprocess(self, model_type=None, custom_stopwords=None, **preprocessing_steps):
@@ -459,8 +457,7 @@ def preprocess(self, model_type=None, custom_stopwords=None, **preprocessing_ste
`texts` attribute and updated in the `dataframe["text"]` column.
"""
if model_type:
preprocessing_steps = load_model_preprocessing_steps(
model_type)
preprocessing_steps = load_model_preprocessing_steps(model_type)
previous_steps = self.preprocessing_steps

# Filter out steps that have already been applied
@@ -503,8 +500,7 @@ def preprocess(self, model_type=None, custom_stopwords=None, **preprocessing_ste
}
)
except Exception as e:
raise RuntimeError(
f"Error in dataset preprocessing: {e}") from e
raise RuntimeError(f"Error in dataset preprocessing: {e}") from e
self.update_preprocessing_steps(**filtered_steps)
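
A hedged sketch of the two entry points into preprocess; the model identifier is a placeholder, since the names accepted by load_model_preprocessing_steps are not visible in this diff:

    # dataset is a TMDataset with a fetched corpus, as in the earlier sketches.
    dataset.preprocess(model_type="SomeTopicModel")   # placeholder identifier

    dataset.preprocess(
        custom_stopwords={"foo", "bar"},   # assumption: forwarded to stopword removal
        remove_special_chars=True,
    )
    # Steps already recorded in dataset.preprocessing_steps are filtered out,
    # so repeated calls should not re-apply the same transformation.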

def update_preprocessing_steps(self, **preprocessing_steps):
@@ -551,8 +547,7 @@ def get_info(self, dataset_path=None):

info_path = os.path.join(dataset_path, f"{self.name}_info.pkl")
if not os.path.exists(info_path):
raise FileNotFoundError(
f"Dataset info file {info_path} does not exist.")
raise FileNotFoundError(f"Dataset info file {info_path} does not exist.")

with open(info_path, "rb") as info_file:
dataset_info = pickle.load(info_file)
@@ -646,8 +641,7 @@ def load_custom_dataset_from_folder(self, dataset_path):
}
)

self.dataframe["tokens"] = self.dataframe["text"].apply(
lambda x: x.split())
self.dataframe["tokens"] = self.dataframe["text"].apply(lambda x: x.split())
self.texts = self.dataframe["text"].tolist()
self.labels = self.dataframe["labels"].tolist()

@@ -743,8 +737,7 @@ def get_bow(self, **kwargs):
"""
corpus = [" ".join(tokens) for tokens in self.get_corpus()]
vectorizer = CountVectorizer(**kwargs)
self.bow = vectorizer.fit_transform(
corpus).toarray().astype(np.float32)
self.bow = vectorizer.fit_transform(corpus).toarray().astype(np.float32)
return self.bow, vectorizer.get_feature_names_out()
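
A hedged usage sketch of get_bow; the keyword arguments are forwarded to scikit-learn's CountVectorizer, and the method returns the dense float32 matrix together with the fitted vocabulary:

    bow_matrix, vocab = dataset.get_bow(max_features=5000, min_df=2)
    print(bow_matrix.shape, len(vocab))   # (n_documents, n_terms) and n_terms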

def get_tfidf(self, **kwargs):
@@ -783,8 +776,10 @@ def has_word_embeddings(self, model_name):
True if word embeddings are available, False otherwise.
"""
return self.has_embeddings(model_name, "word_embeddings")

def save_word_embeddings(self, word_embeddings, model_name, path=None, file_name=None):

def save_word_embeddings(
self, word_embeddings, model_name, path=None, file_name=None
):
"""
Save word embeddings for the dataset.
@@ -802,7 +797,7 @@ def save_word_embeddings(self, word_embeddings, model_name, path=None, file_nam
file_name=file_name,
)

def get_word_embeddings(self, model_name="glove-wiki-gigaword-100", vocab = None):
def get_word_embeddings(self, model_name="glove-wiki-gigaword-100", vocab=None):
"""
Get the word embeddings for the vocabulary using a pre-trained model.
@@ -840,11 +835,15 @@ def get_word_embeddings(self, model_name="glove-wiki-gigaword-100", vocab = None
if model_name == "paraphrase-MiniLM-L3-v2":
model = SentenceTransformer(model_name)
vocabulary = list(vocabulary)
embeddings = model.encode(vocabulary, convert_to_tensor=True, show_progress_bar=True)
embeddings = model.encode(
vocabulary, convert_to_tensor=True, show_progress_bar=True
)

embeddings = {word: embeddings[i] for i, word in enumerate(vocabulary)}

assert len(embeddings) == len(vocabulary), "Embeddings and vocabulary length mismatch"
assert len(embeddings) == len(
vocabulary
), "Embeddings and vocabulary length mismatch"

return embeddings
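
A hedged sketch of get_word_embeddings; both calls assume the corresponding model weights can be downloaded or are cached locally, and the lookup word is only an example:

    glove_embs = dataset.get_word_embeddings()  # default: glove-wiki-gigaword-100
    st_embs = dataset.get_word_embeddings(model_name="paraphrase-MiniLM-L3-v2")
    vector = st_embs["topic"]   # dict keyed by vocabulary word;
                                # assumes "topic" is in the vocabulary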

@@ -874,7 +873,6 @@ def load_dataset_from_parquet(self, load_path):
if not os.path.exists(load_path):
raise FileNotFoundError(f"File {load_path} does not exist.")
self.dataframe = pd.read_parquet(load_path)
self.dataframe["tokens"] = self.dataframe["text"].apply(
lambda x: x.split())
self.dataframe["tokens"] = self.dataframe["text"].apply(lambda x: x.split())
self.texts = self.dataframe["text"].tolist()
self.labels = self.dataframe["labels"].tolist()
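
A hedged sketch of reloading a saved dataset; the path is a placeholder, and the Parquet file is expected to contain at least "text" and "labels" columns, since tokens are rebuilt by splitting the text column on whitespace:

    dataset = TMDataset()
    dataset.load_dataset_from_parquet("./my_datasets/my_custom_dataset.parquet")
    print(dataset.texts[:2], dataset.labels[:2])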
