diff --git a/tf_keras_vis/utils/__init__.py b/tf_keras_vis/utils/__init__.py index 38a2f4a..3d30fc8 100644 --- a/tf_keras_vis/utils/__init__.py +++ b/tf_keras_vis/utils/__init__.py @@ -135,4 +135,8 @@ def lower_precision_dtype(model): def get_input_names(model): - return [input.name for input in model.inputs] + if version(tf.version.VERSION) >= version("2.4.0"): + names = [input.name for input in model.inputs] + else: + names = model.input_names + return names