Skip to content

Commit

Permalink
Improve support for remote loading and saving.
Browse files Browse the repository at this point in the history
- `_support_gcs_uri` did not work for saving APIs. It would create a temp folder, but the file created in the temp folder would never actually be copied to the remote path. Introduced a `SupportWriteToRemote`, a context manager that copies the temp file to remote on exit and cleans up the temp folder.
- there were several different places handling remote loading, but none of them would clean up the temp folders created. Introduced a `SupportReadFromRemote`, a context manager that copies the remote file to a temp folder and cleans up on exit.
- removed ad-hoc support for remote zip file saving; instead, this is handled the same way as other remote saving cases, with a temp file, which removes the need to hold the whole zip archive in memory.

PiperOrigin-RevId: 700043742
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Nov 27, 2024
1 parent 916ca64 commit 9d46c94
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 142 deletions.
290 changes: 164 additions & 126 deletions tf_keras/saving/saving_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Public API surface for saving APIs."""

import os
import tempfile
import warnings
import zipfile

Expand All @@ -33,17 +34,76 @@
is_oss = True


def _support_gcs_uri(filepath, save_format, is_oss):
"""Supports GCS URIs through bigstore via a temporary file."""
gs_filepath = None
if str(filepath).startswith("gs://") and save_format != "tf":
gs_filepath = filepath
if not is_oss:
gs_filepath = filepath.replace("gs://", "/bigstore/")
filepath = os.path.join(
saving_lib.get_temp_dir(), os.path.basename(gs_filepath)
)
return gs_filepath, filepath
class SupportReadFromRemote:
    """Supports GCS URIs and other remote paths via a temporary file.

    This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
    supports remote saved models out of the box.

    On construction, a remote, non-directory file whose save format is not
    `"tf"` is copied into a fresh temporary directory. Entering the context
    yields the path to read from (the local copy, or the original `filepath`
    when no copy was needed). Exiting the context deletes the temporary
    directory, if one was created.

    Args:
        filepath: Path or URI of the file to read.
    """

    def __init__(self, filepath):
        save_format = get_save_format(filepath, save_format=None)
        if (
            saving_lib.is_remote_path(filepath)
            and not tf.io.gfile.isdir(filepath)
            and save_format != "tf"
        ):
            self.temp_directory = tempfile.TemporaryDirectory()
            gs_filepath = filepath
            # Outside of OSS, GCS URIs are accessed through `/bigstore/`.
            if not is_oss and str(filepath).startswith("gs://"):
                gs_filepath = filepath.replace("gs://", "/bigstore/")
            self.local_filepath = os.path.join(
                self.temp_directory.name, os.path.basename(filepath)
            )
            try:
                tf.io.gfile.copy(
                    gs_filepath, self.local_filepath, overwrite=True
                )
            except Exception:
                # `__exit__` never runs when `__init__` raises, so release
                # the temporary directory here instead of relying on
                # implicit finalization.
                self.temp_directory.cleanup()
                raise
        else:
            # Local file, remote directory or "tf" format: read in place.
            self.temp_directory = None
            self.local_filepath = filepath

    def __enter__(self):
        """Returns the path the caller should read from."""
        return self.local_filepath

    def __exit__(self, exc_type, exc_value, traceback):
        """Deletes the temporary directory, if one was created."""
        if self.temp_directory is not None:
            self.temp_directory.cleanup()


class SupportWriteToRemote:
    """Supports GCS URIs and other remote paths via a temporary file.

    This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
    supports remote saved models out of the box.

    When `filepath` is remote and the save format is not `"tf"`, entering
    the context yields a path inside a fresh temporary directory. On a
    clean exit the file written there is copied to the remote destination;
    the temporary directory is always deleted. Otherwise the context yields
    `filepath` unchanged and exiting does nothing.

    Args:
        filepath: Path or URI of the file to write.
        overwrite: Forwarded to `tf.io.gfile.copy` when copying the
            temporary file to the remote destination.
        save_format: Optional explicit save format, resolved together with
            `filepath` by `get_save_format`.
    """

    def __init__(self, filepath, overwrite=True, save_format=None):
        save_format = get_save_format(filepath, save_format=save_format)
        self.overwrite = overwrite
        if saving_lib.is_remote_path(filepath) and save_format != "tf":
            self.temp_directory = tempfile.TemporaryDirectory()
            self.remote_filepath = filepath
            # Outside of OSS, GCS URIs are accessed through `/bigstore/`.
            if not is_oss and str(filepath).startswith("gs://"):
                self.remote_filepath = self.remote_filepath.replace(
                    "gs://", "/bigstore/"
                )
            self.local_filepath = os.path.join(
                self.temp_directory.name, os.path.basename(filepath)
            )
        else:
            # Local path or "tf" format: write in place.
            self.temp_directory = None
            self.remote_filepath = None
            self.local_filepath = filepath

    def __enter__(self):
        """Returns the path the caller should write to."""
        return self.local_filepath

    def __exit__(self, exc_type, exc_value, traceback):
        """Uploads the temporary file on success and always cleans up."""
        if self.temp_directory is None:
            return
        try:
            # Only upload when the body completed normally. After a failure
            # the local file may be partial or missing, and a copy error
            # raised here would mask the original exception.
            if exc_type is None:
                tf.io.gfile.copy(
                    self.local_filepath,
                    self.remote_filepath,
                    overwrite=self.overwrite,
                )
        finally:
            # Delete the temporary directory even if the upload fails.
            self.temp_directory.cleanup()


@keras_export("keras.saving.save_model", "keras.models.save_model")
Expand Down Expand Up @@ -131,46 +191,49 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
when loading the model. See the `custom_objects` argument in
`tf.keras.saving.load_model`.
"""
save_format = get_save_format(filepath, save_format)

# Supports GCS URIs through bigstore via a temporary file
gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)

# Deprecation warnings
if save_format == "h5":
warnings.warn(
"You are saving your model as an HDF5 file via `model.save()`. "
"This file format is considered legacy. "
"We recommend using instead the native TF-Keras format, "
"e.g. `model.save('my_model.keras')`.",
stacklevel=2,
)

if save_format == "keras":
# If file exists and should not be overwritten.
try:
exists = os.path.exists(filepath)
except TypeError:
exists = False
if exists and not overwrite:
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
if kwargs:
raise ValueError(
"The following argument(s) are not supported "
f"with the native TF-Keras format: {list(kwargs.keys())}"
)
saving_lib.save_model(model, filepath)
else:
# Legacy case
return legacy_sm_saving_lib.save_model(
model,
# Supports remote paths via a temporary file
with SupportWriteToRemote(
filepath,
overwrite=overwrite,
save_format=save_format,
**kwargs,
)
) as local_filepath:
save_format = get_save_format(filepath, save_format)

# Deprecation warnings
if save_format == "h5":
warnings.warn(
"You are saving your model as an HDF5 file via `model.save()`. "
"This file format is considered legacy. "
"We recommend using instead the native TF-Keras format, "
"e.g. `model.save('my_model.keras')`.",
stacklevel=2,
)

if save_format == "keras":
# If file exists and should not be overwritten.
try:
exists = os.path.exists(local_filepath)
except TypeError:
exists = False
if exists and not overwrite:
proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
if not proceed:
return
if kwargs:
raise ValueError(
"The following argument(s) are not supported "
f"with the native TF-Keras format: {list(kwargs.keys())}"
)
saving_lib.save_model(model, local_filepath)
else:
# Legacy case
return legacy_sm_saving_lib.save_model(
model,
local_filepath,
overwrite=overwrite,
save_format=save_format,
**kwargs,
)


@keras_export("keras.saving.load_model", "keras.models.load_model")
Expand Down Expand Up @@ -217,94 +280,69 @@ def load_model(
It is recommended that you use layer attributes to
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
"""
# Supports GCS URIs by copying data to temporary file
save_format = get_save_format(filepath, save_format=None)
gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
if gs_filepath is not None:
tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)

is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
filepath
)

# Support for remote zip files
if (
saving_lib.is_remote_path(filepath)
and not tf.io.gfile.isdir(filepath)
and not is_keras_zip
):
local_path = os.path.join(
saving_lib.get_temp_dir(), os.path.basename(filepath)
)

# Copy from remote to temporary local directory
tf.io.gfile.copy(filepath, local_path, overwrite=True)

# Switch filepath to local zipfile for loading model
if zipfile.is_zipfile(local_path):
filepath = local_path
is_keras_zip = True

if is_keras_zip:
if kwargs:
raise ValueError(
"The following argument(s) are not supported "
f"with the native TF-Keras format: {list(kwargs.keys())}"
# Supports remote paths via a temporary file
with SupportReadFromRemote(filepath) as local_filepath:
if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
local_filepath
):
if kwargs:
raise ValueError(
"The following argument(s) are not supported "
f"with the native TF-Keras format: {list(kwargs.keys())}"
)
return saving_lib.load_model(
local_filepath,
custom_objects=custom_objects,
compile=compile,
safe_mode=safe_mode,
)
return saving_lib.load_model(
filepath,

# Legacy case.
return legacy_sm_saving_lib.load_model(
local_filepath,
custom_objects=custom_objects,
compile=compile,
safe_mode=safe_mode,
**kwargs,
)

# Legacy case.
return legacy_sm_saving_lib.load_model(
filepath, custom_objects=custom_objects, compile=compile, **kwargs
)


def save_weights(model, filepath, overwrite=True, **kwargs):
# Supports GCS URIs through bigstore via a temporary file
save_format = get_save_format(filepath, save_format=None)
gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)

if str(filepath).endswith(".weights.h5"):
# If file exists and should not be overwritten.
try:
exists = os.path.exists(filepath)
except TypeError:
exists = False
if exists and not overwrite:
proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
if not proceed:
return
saving_lib.save_weights_only(model, filepath)
else:
legacy_sm_saving_lib.save_weights(
model, filepath, overwrite=overwrite, **kwargs
)
# Supports remote paths via a temporary file
with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
if str(local_filepath).endswith(".weights.h5"):
# If file exists and should not be overwritten.
try:
exists = os.path.exists(local_filepath)
except TypeError:
exists = False
if exists and not overwrite:
proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
if not proceed:
return
saving_lib.save_weights_only(model, local_filepath)
else:
legacy_sm_saving_lib.save_weights(
model, local_filepath, overwrite=overwrite, **kwargs
)


def load_weights(model, filepath, skip_mismatch=False, **kwargs):
# Supports GCS URIs by copying data to temporary file
save_format = get_save_format(filepath, save_format=None)
gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
if gs_filepath is not None:
tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)

if str(filepath).endswith(".keras") and zipfile.is_zipfile(filepath):
saving_lib.load_weights_only(
model, filepath, skip_mismatch=skip_mismatch
)
elif str(filepath).endswith(".weights.h5"):
saving_lib.load_weights_only(
model, filepath, skip_mismatch=skip_mismatch
)
else:
return legacy_sm_saving_lib.load_weights(
model, filepath, skip_mismatch=skip_mismatch, **kwargs
)
# Supports remote paths via a temporary file
with SupportReadFromRemote(filepath) as local_filepath:
if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
local_filepath
):
saving_lib.load_weights_only(
model, local_filepath, skip_mismatch=skip_mismatch
)
elif str(local_filepath).endswith(".weights.h5"):
saving_lib.load_weights_only(
model, local_filepath, skip_mismatch=skip_mismatch
)
else:
return legacy_sm_saving_lib.load_weights(
model, local_filepath, skip_mismatch=skip_mismatch, **kwargs
)


def get_save_format(filepath, save_format):
Expand Down
12 changes: 1 addition & 11 deletions tf_keras/saving/saving_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,8 @@ def save_model(model, filepath, weights_format="h5"):
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
}
)
# TODO(rameshsampath): Need a better logic for local vs remote path
if is_remote_path(filepath):
# Remote path. Zip to local memory byte io and copy to remote
zip_filepath = io.BytesIO()
else:
zip_filepath = filepath
try:
with zipfile.ZipFile(zip_filepath, "w") as zf:
with zipfile.ZipFile(filepath, "w") as zf:
with zf.open(_METADATA_FILENAME, "w") as f:
f.write(metadata_json.encode())
with zf.open(_CONFIG_FILENAME, "w") as f:
Expand Down Expand Up @@ -195,10 +189,6 @@ def save_model(model, filepath, weights_format="h5"):
)
weights_store.close()
asset_store.close()

if is_remote_path(filepath):
with tf.io.gfile.GFile(filepath, "wb") as f:
f.write(zip_filepath.getvalue())
except Exception as e:
raise e
finally:
Expand Down
12 changes: 7 additions & 5 deletions tf_keras/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,24 +533,26 @@ def test_metadata(self):
self.assertIn("keras_version", metadata)
self.assertIn("date_saved", metadata)

def test_gfile_local_called(self):
def test_gfile_copy_called(self):
temp_filepath = Path(
os.path.join(self.get_temp_dir(), "my_model.keras")
)
model = CompileOverridingModel()
with mock.patch(
"re.match", autospec=True
) as mock_re_match, mock.patch.object(
tf.io.gfile, "GFile"
) as mock_gfile:
tf.io.gfile, "copy"
) as mock_gfile_copy:
# Check regex matching
mock_re_match.return_value = True
model.save(temp_filepath, save_format="keras_v3")
mock_re_match.assert_called()
self.assertIn(str(temp_filepath), mock_re_match.call_args.args)

# Check gfile opened with filepath specified
self.assertIn(str(temp_filepath), mock_gfile.call_args.args)
# Check gfile copied with filepath specified as destination
self.assertEqual(
str(temp_filepath), str(mock_gfile_copy.call_args.args[1])
)

def test_load_model_api_endpoint(self):
temp_filepath = Path(os.path.join(self.get_temp_dir(), "mymodel.keras"))
Expand Down

0 comments on commit 9d46c94

Please sign in to comment.