Skip to content

Commit

Permalink
Fixes for Sequential model with multiple inputs.
Browse files Browse the repository at this point in the history
- While `Sequential` works with multiple inputs in most scenarios, `build()` did not allow building with multiple inputs. This is now fixed.
- Removed the `build_input_shape` from the new serialization format. This is a legacy concept, which has been replaced with `build_config.input_shape` in the new format. Having both could cause models to be built twice.
- `build_from_config` now always calls `build` with `TensorShape`s, not tuples. Not all layers handle tuples correctly.

PiperOrigin-RevId: 716010092
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Jan 23, 2025
1 parent a23abb2 commit cd0a0d7
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tf_keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,7 @@ def build_from_config(self, config):
"""
input_shape = config["input_shape"]
if input_shape is not None:
self.build(input_shape)
self.build(tf_utils.convert_shapes(input_shape, to_tuples=False))

############################################################################
# Methods & attributes below are all private and only used by the framework.
Expand Down
41 changes: 29 additions & 12 deletions tf_keras/engine/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,16 @@ def _build_graph_network_for_inferred_shape(
):
# Determine whether the input shape is novel, i.e. whether the model
# should be rebuilt.
input_shape = tuple(input_shape)
input_shape = tf_utils.convert_shapes(input_shape)
if self._inferred_input_shape is None:
new_shape = input_shape
else:
new_shape = relax_input_shape(
self._inferred_input_shape, input_shape
new_shape = tf.nest.map_structure(
_relax_input_shape,
tf_utils.convert_shapes(
self._inferred_input_shape, to_tuples=False
),
tf_utils.convert_shapes(input_shape, to_tuples=False),
)
if (
new_shape is not None
Expand All @@ -299,10 +303,13 @@ def _build_graph_network_for_inferred_shape(
# A novel shape has been received: we need to rebuild the model.
# In case we are inside a graph function, we step out of it.
with tf.init_scope():
inputs = input_layer.Input(
batch_shape=new_shape,
dtype=input_dtype,
name=self.layers[0].name + "_input",
inputs = tf.nest.map_structure(
lambda s: input_layer.Input(
batch_shape=tf_utils.convert_shapes(s),
dtype=input_dtype,
name=self.layers[0].name + "_input",
),
tf_utils.convert_shapes(new_shape, to_tuples=False),
)
layer_input = inputs
created_nodes = set()
Expand Down Expand Up @@ -370,7 +377,7 @@ def build(self, input_shape=None):
raise ValueError("You must provide an `input_shape` argument.")
self._build_graph_network_for_inferred_shape(input_shape)
if not self.built:
input_shape = tuple(input_shape)
input_shape = tf_utils.convert_shapes(input_shape)
self._build_input_shape = input_shape
super().build(input_shape)
self.built = True
Expand Down Expand Up @@ -435,7 +442,8 @@ def compute_mask(self, inputs, mask):
def get_config(self):
layer_configs = []
serialize_obj_fn = serialization_lib.serialize_keras_object
if getattr(self, "use_legacy_config", None):
use_legacy_config = getattr(self, "use_legacy_config", False)
if use_legacy_config:
serialize_obj_fn = legacy_serialization.serialize_keras_object
for layer in super().layers:
# `super().layers` include the InputLayer if available (it is
Expand All @@ -446,7 +454,11 @@ def get_config(self):
config = training.Model.get_config(self)
config["name"] = self.name
config["layers"] = copy.deepcopy(layer_configs)
if not self._is_graph_network and self._build_input_shape is not None:
if (
use_legacy_config
and not self._is_graph_network
and self._build_input_shape
):
config["build_input_shape"] = self._build_input_shape
return config

Expand All @@ -458,6 +470,7 @@ def from_config(cls, config, custom_objects=None):
layer_configs = config["layers"]
else:
name = None
build_input_shape = None
layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
Expand Down Expand Up @@ -519,11 +532,15 @@ def _get_shape_tuple(t):
return None


def relax_input_shape(shape_1, shape_2):
def _relax_input_shape(shape_1, shape_2):
if shape_1 is None or shape_2 is None:
return None
if len(shape_1) != len(shape_2):
if shape_1.rank is None or shape_2.rank is None:
return None
if shape_1.rank != shape_2.rank:
return None
shape_1 = shape_1.as_list()
shape_2 = shape_2.as_list()
return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2))


Expand Down
27 changes: 27 additions & 0 deletions tf_keras/engine/sequential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized

import tf_keras as keras
from tf_keras.saving import object_registration
from tf_keras.testing_infra import test_combinations
from tf_keras.testing_infra import test_utils

Expand Down Expand Up @@ -574,6 +575,22 @@ def test_multi_inputs_outputs(self):
model(image_inputs)
model.fit(x=image_inputs, y=image_inputs, steps_per_epoch=1)

@test_combinations.run_all_keras_modes(always_skip_v1=True)
def test_multi_inputs_build(self):
    """Check `build()` with a dict of shapes and a JSON round-trip.

    Builds a `Sequential` model whose only layer consumes a dict of two
    inputs, calls it, then recreates the model from its JSON config and
    asserts that both models produce the same output.
    """
    model = keras.Sequential([ImageMultiplyLayer()])
    model.build({"images": (None, 512, 512, 3), "weights": (None, 3)})

    inputs = {
        "images": tf.ones((2, 512, 512, 3)),
        "weights": tf.ones((2, 3)),
    }
    output = model(inputs)

    # Serialize to JSON and rebuild; the reloaded model gets its own
    # dict holding the same tensors.
    restored = keras.models.model_from_json(model.to_json())
    self.assertAllClose(output, restored(dict(inputs)))


class TestSequentialEagerIntegration(test_combinations.TestCase):
@test_combinations.run_all_keras_modes
Expand Down Expand Up @@ -642,10 +659,20 @@ def test_build_empty_network(self):
self.assertTrue(model.built)


@object_registration.register_keras_serializable()
class ImageAugmentLayer(keras.layers.Layer):
    """Pass-through test layer registered for Keras serialization."""

    def call(self, inputs):
        """Return `inputs` unchanged."""
        return inputs


@object_registration.register_keras_serializable()
class ImageMultiplyLayer(keras.layers.Layer):
    """Test layer that scales images by per-sample channel weights.

    Consumes a dict input with the keys `"images"` and `"weights"`. In
    `test_multi_inputs_build` these are built with shapes
    `(None, 512, 512, 3)` and `(None, 3)` respectively.
    """

    def call(self, inputs):
        """Multiply `inputs["images"]` by `inputs["weights"]`.

        Args:
            inputs: Dict with an `"images"` tensor and a `"weights"` tensor
                holding three values per sample.

        Returns:
            The images scaled channel-wise by the weights of their sample.
        """
        images = inputs["images"]
        weights = inputs["weights"]
        # Reshape the weights, not the images: `(batch, 3)` becomes
        # `(batch, 1, 1, 3)` so that it broadcasts across the two spatial
        # dimensions. Reshaping the images instead would fold batch and
        # spatial dims together and mix weights across samples.
        weights = tf.reshape(weights, (-1, 1, 1, 3))
        return images * weights


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    tf.test.main()

0 comments on commit cd0a0d7

Please sign in to comment.