Skip to content

Commit

Permalink
Merge pull request #2939 from RaRe-Technologies/2vec_saveload_fixes
Browse files Browse the repository at this point in the history
[MRG] *2Vec SaveLoad improvements
  • Loading branch information
piskvorky authored Sep 24, 2020
2 parents 08a61e5 + da8847a commit c6c24ea
Show file tree
Hide file tree
Showing 21 changed files with 515 additions and 363 deletions.
1 change: 1 addition & 0 deletions gensim/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
By default, this subdirectory is ~/gensim-data.
"""

from __future__ import absolute_import
import argparse
import os
Expand Down
2 changes: 1 addition & 1 deletion gensim/models/_fasttext_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def _dict_save(fout, model, encoding):
# prunedidx_size_=-1, -1 value denotes no pruning index (pruning is only supported in supervised mode)
fout.write(np.int64(-1))

for word in model.wv.index2word:
for word in model.wv.index_to_key:
word_count = model.wv.get_vecattr(word, 'count')
fout.write(word.encode(encoding))
fout.write(_END_OF_WORD_MARKER)
Expand Down
9 changes: 6 additions & 3 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Internal functions for pipelines.
"""

import logging
import multiprocessing as mp
from collections import namedtuple
Expand All @@ -33,9 +34,11 @@

from gensim import interfaces, matutils
from gensim import utils
from gensim.topic_coherence import (segmentation, probability_estimation,
direct_confirmation_measure, indirect_confirmation_measure,
aggregation)
from gensim.topic_coherence import (
segmentation, probability_estimation,
direct_confirmation_measure, indirect_confirmation_measure,
aggregation,
)
from gensim.topic_coherence.probability_estimation import unique_ids_from_segments

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def load(cls, *args, **kwargs):
except AttributeError as ae:
logger.error(
"Model load error. Was model saved using code from an older Gensim Version? "
"Try loading older model using gensim-3.8.1, then re-saving, to restore "
"Try loading older model using gensim-3.8.3, then re-saving, to restore "
"compatibility with current code.")
raise ae

Expand Down
89 changes: 58 additions & 31 deletions gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,6 @@ def save(self, *args, **kwargs):
Load :class:`~gensim.models.fasttext.FastText` model.
"""
kwargs['ignore'] = kwargs.get('ignore', []) + ['buckets_word', ]
super(FastText, self).save(*args, **kwargs)

@classmethod
Expand All @@ -850,25 +849,15 @@ def load(cls, *args, **kwargs):
Save :class:`~gensim.models.fasttext.FastText` model.
"""
model = super(FastText, cls).load(*args, rethrow=True, **kwargs)

if not hasattr(model.wv, 'vectors_vocab_lockf') and hasattr(model.wv, 'vectors_vocab'):
# TODO: try trainables-location
model.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
if not hasattr(model, 'vectors_ngrams_lockf') and hasattr(model.wv, 'vectors_ngrams'):
# TODO: try trainables-location
model.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)
# fixup mistakenly overdimensioned gensim-3.x lockf arrays
if len(model.wv.vectors_vocab_lockf.shape) > 1:
model.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
if len(model.wv.vectors_ngrams_lockf.shape) > 1:
model.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)
if hasattr(model, 'bucket'):
del model.bucket # should only exist in one place: the wv subcomponent
if not hasattr(model.wv, 'buckets_word') or not model.wv.buckets_word:
model.wv.recalc_char_ngram_buckets()
return super(FastText, cls).load(*args, rethrow=True, **kwargs)

return model
def _load_specials(self, *args, **kwargs):
    """Handle special requirements of `.load()` protocol, usually up-converting older versions.

    Called as part of the SaveLoad deserialization machinery; delegates to the
    superclass first, then repairs legacy layouts.
    """
    super(FastText, self)._load_specials(*args, **kwargs)
    if hasattr(self, 'bucket'):
        # Older models stored `bucket` on the model itself; it should only
        # exist in one place: the wv subcomponent. Move it there and remove
        # the model-level copy.
        self.wv.bucket = self.bucket
        del self.bucket


class FastTextVocab(utils.SaveLoad):
Expand Down Expand Up @@ -1202,12 +1191,49 @@ def __init__(self, vector_size, min_n, max_n, bucket):

@classmethod
def load(cls, fname_or_handle, **kwargs):
model = super(FastTextKeyedVectors, cls).load(fname_or_handle, **kwargs)
if isinstance(model, FastTextKeyedVectors):
if not hasattr(model, 'compatible_hash') or model.compatible_hash is False:
raise TypeError("Pre-gensim-3.8.x Fasttext models with nonstandard hashing are no longer compatible."
"Loading into gensim-3.8.3 & re-saving may create a compatible model.")
return model
"""Load a previously saved `FastTextKeyedVectors` model.
Parameters
----------
fname : str
Path to the saved file.
Returns
-------
:class:`~gensim.models.fasttext.FastTextKeyedVectors`
Loaded model.
See Also
--------
:meth:`~gensim.models.fasttext.FastTextKeyedVectors.save`
Save :class:`~gensim.models.fasttext.FastTextKeyedVectors` model.
"""
return super(FastTextKeyedVectors, cls).load(fname_or_handle, **kwargs)

def _load_specials(self, *args, **kwargs):
    """Handle special requirements of `.load()` protocol, usually up-converting older versions.

    Raises
    ------
    TypeError
        If the loaded object is not a `FastTextKeyedVectors`, or if it is a
        pre-gensim-3.8.x model using a nonstandard (incompatible) hash function.
    """
    super(FastTextKeyedVectors, self)._load_specials(*args, **kwargs)
    if not isinstance(self, FastTextKeyedVectors):
        raise TypeError("Loaded object of type %s, not expected FastTextKeyedVectors" % type(self))
    if not hasattr(self, 'compatible_hash') or self.compatible_hash is False:
        # Old broken-hash models cannot be up-converted in place; the user
        # must round-trip through gensim 3.8.3 first.
        raise TypeError(
            "Pre-gensim-3.8.x fastText models with nonstandard hashing are no longer compatible. "
            "Loading your old model into gensim-3.8.3 & re-saving may create a model compatible with gensim 4.x."
        )
    # Older saves may lack the lockf (learning-rate lock factor) arrays; create
    # trivial 1-element defaults when the corresponding vectors exist.
    if not hasattr(self, 'vectors_vocab_lockf') and hasattr(self, 'vectors_vocab'):
        self.vectors_vocab_lockf = ones(1, dtype=REAL)
    if not hasattr(self, 'vectors_ngrams_lockf') and hasattr(self, 'vectors_ngrams'):
        self.vectors_ngrams_lockf = ones(1, dtype=REAL)
    # fixup mistakenly overdimensioned gensim-3.x lockf arrays
    if len(self.vectors_vocab_lockf.shape) > 1:
        self.vectors_vocab_lockf = ones(1, dtype=REAL)
    if len(self.vectors_ngrams_lockf.shape) > 1:
        self.vectors_ngrams_lockf = ones(1, dtype=REAL)
    if not hasattr(self, 'buckets_word') or not self.buckets_word:
        # buckets_word is a derived cache (not saved); rebuild it.
        self.recalc_char_ngram_buckets()
    if not hasattr(self, 'vectors') or self.vectors is None:
        self.adjust_vectors()  # recompose full-word vectors

def __contains__(self, word):
"""Check if `word` or any character ngrams in `word` are present in the vocabulary.
Expand Down Expand Up @@ -1255,14 +1281,15 @@ def save(self, *args, **kwargs):
Load object.
"""
# don't bother storing the cached normalized vectors
ignore_attrs = [
'buckets_word',
'hash2index',
]
kwargs['ignore'] = kwargs.get('ignore', ignore_attrs)
super(FastTextKeyedVectors, self).save(*args, **kwargs)

def _save_specials(self, fname, separately, sep_limit, ignore, pickle_protocol, compress, subname):
"""Arrange any special handling for the gensim.utils.SaveLoad protocol"""
# don't save properties that are merely calculated from others
ignore = set(ignore).union(['buckets_word', 'vectors', ])
return super(FastTextKeyedVectors, self)._save_specials(
fname, separately, sep_limit, ignore, pickle_protocol, compress, subname)

def get_vector(self, word, norm=False):
"""Get `word` representations in vector space, as a 1D numpy array.
Expand Down
Loading

0 comments on commit c6c24ea

Please sign in to comment.