Skip to content

Commit

Permalink
Fix object_registration_test.
Browse files Browse the repository at this point in the history
This test was not run at all because it was missing a `main`.
The tests were changed to match the current behavior of object registration.

PiperOrigin-RevId: 651490221
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Jul 11, 2024
1 parent 64108a1 commit 6d0a5e6
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions tf_keras/saving/object_registration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,24 @@ def __init__(self, value):
def get_config(self):
return {"value": self._value}

@classmethod
def from_config(cls, config):
return cls(**config)

serialized_name = "Custom>TestClass"
inst = TestClass(value=10)
class_name = object_registration._GLOBAL_CUSTOM_NAMES[TestClass]
self.assertEqual(serialized_name, class_name)
config = serialization_lib.serialize_keras_object(inst)
self.assertEqual(class_name, config["class_name"])
if tf.__internal__.tf2.enabled():
self.assertEqual(class_name, config["registered_name"])
else:
self.assertEqual(class_name, config["class_name"])
new_inst = serialization_lib.deserialize_keras_object(config)
self.assertIsNot(inst, new_inst)
self.assertIsInstance(new_inst, TestClass)
self.assertEqual(10, new_inst._value)

# Make sure registering a new class with same name will fail.
with self.assertRaisesRegex(
ValueError, ".*has already been registered.*"
):

@object_registration.register_keras_serializable()
class TestClass:
def __init__(self, value):
self._value = value

def get_config(self):
return {"value": self._value}

def test_serialize_custom_class_with_custom_name(self):
@object_registration.register_keras_serializable(
"TestPackage", "CustomName"
Expand All @@ -93,6 +87,10 @@ def __init__(self, val):
def get_config(self):
return {"val": self._val}

@classmethod
def from_config(cls, config):
return cls(**config)

serialized_name = "TestPackage>CustomName"
inst = OtherTestClass(val=5)
class_name = object_registration._GLOBAL_CUSTOM_NAMES[OtherTestClass]
Expand All @@ -103,9 +101,12 @@ def get_config(self):
cls = object_registration.get_registered_object(fn_class_name)
self.assertEqual(OtherTestClass, cls)

config = keras.utils.serialization.serialize_keras_object(inst)
self.assertEqual(class_name, config["class_name"])
new_inst = keras.utils.serialization.deserialize_keras_object(config)
config = keras.saving.serialization_lib.serialize_keras_object(inst)
if tf.__internal__.tf2.enabled():
self.assertEqual(class_name, config["registered_name"])
else:
self.assertEqual(class_name, config["class_name"])
new_inst = keras.saving.serialization_lib.deserialize_keras_object(config)
self.assertIsNot(inst, new_inst)
self.assertIsInstance(new_inst, OtherTestClass)
self.assertEqual(5, new_inst._val)
Expand All @@ -121,9 +122,12 @@ def my_fn():
fn_class_name = object_registration.get_registered_name(my_fn)
self.assertEqual(fn_class_name, class_name)

config = keras.utils.serialization.serialize_keras_object(my_fn)
self.assertEqual(class_name, config)
fn = keras.utils.serialization.deserialize_keras_object(config)
config = keras.saving.serialization_lib.serialize_keras_object(my_fn)
if tf.__internal__.tf2.enabled():
self.assertEqual("my_fn", config["config"])
else:
self.assertEqual(class_name, config)
fn = keras.saving.serialization_lib.deserialize_keras_object(config)
self.assertEqual(42, fn())

fn_2 = object_registration.get_registered_object(fn_class_name)
Expand All @@ -141,3 +145,7 @@ def test_serialize_custom_class_without_get_config_fails(self):
class TestClass:
def __init__(self, value):
self._value = value


if __name__ == "__main__":
tf.test.main()

0 comments on commit 6d0a5e6

Please sign in to comment.