Skip to content

Commit

Permalink
Update universal_sentence_encoder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yangheng95 committed Aug 30, 2023
1 parent d5bb399 commit b2c8bad
Showing 1 changed file with 3 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,16 @@ def __init__(self, threshold=0.8, large=False, metric="angular", **kwargs):
super().__init__(threshold=threshold, metric=metric, **kwargs)
if large:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
mirror_tfhub_url = "https://hub.tensorflow.google.cn/google/universal-sentence-encoder-large/5"
else:
tfhub_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
mirror_tfhub_url = (
"https://hub.tensorflow.google.cn/google/universal-sentence-encoder/4"
)

self._tfhub_url = tfhub_url
self.mirror_tfhub_url = mirror_tfhub_url
# Lazily load the model
self.model = None

def encode(self, sentences):
if not self.model:
try:
self.model = hub.load(self._tfhub_url)
except Exception as e:
print('Error loading model from tfhub, trying mirror url')
self.model = hub.load(self.mirror_tfhub_url)
self.model = hub.load(self._tfhub_url)
return self.model(sentences).numpy()

def __getstate__(self):
Expand All @@ -46,8 +37,5 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state
try:
self.model = hub.load(self._tfhub_url)
except Exception as e:
print('Error loading model from tfhub, trying mirror url')
self.model = hub.load(self.mirror_tfhub_url)
self.model = hub.load(self._tfhub_url)

0 comments on commit b2c8bad

Please sign in to comment.