Skip to content

Commit

Permalink
Fix export for subclass models with multiple inputs. (#19720)
Browse files Browse the repository at this point in the history
The export now supports subclasses of `Model` for which the `call` method takes more than one input argument. Note that it is required for the model class to implement a `build` method with a signature that matches the `call` method.
  • Loading branch information
hertschuh committed May 17, 2024
1 parent 6e40533 commit 20bc267
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
18 changes: 4 additions & 14 deletions keras/src/export/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,18 +621,17 @@ def export_model(model, filepath):
input_signature = [input_signature]
export_archive.add_endpoint("serve", model.__call__, input_signature)
else:
save_spec = _get_save_spec(model)
if not save_spec or not model._called:
input_signature = _get_input_signature(model)
if not input_signature or not model._called:
raise ValueError(
"The model provided has never called. "
"It must be called at least once before export."
)
input_signature = [save_spec]
export_archive.add_endpoint("serve", model.__call__, input_signature)
export_archive.write_out(filepath)


def _get_save_spec(model):
def _get_input_signature(model):
shapes_dict = getattr(model, "_build_shapes_dict", None)
if not shapes_dict:
return None
Expand All @@ -654,16 +653,7 @@ def make_tensor_spec(structure):
f"Unsupported type {type(structure)} for {structure}"
)

if len(shapes_dict) == 1:
value = list(shapes_dict.values())[0]
return make_tensor_spec(value)

specs = {}
for key, value in shapes_dict.items():
key = key.rstrip("_shape")
specs[key] = make_tensor_spec(value)

return specs
return [make_tensor_spec(value) for value in shapes_dict.values()]


@keras_export("keras.layers.TFSMLayer")
Expand Down
25 changes: 25 additions & 0 deletions keras/src/export/export_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,31 @@ def call(self, inputs):
)
revived_model.serve(bigger_input)

def test_model_with_multiple_inputs(self):

class TwoInputsModel(models.Model):
def call(self, x, y):
return x + y

def build(self, y_shape, x_shape):
self.built = True

temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = TwoInputsModel()
ref_input_x = tf.random.normal((3, 10))
ref_input_y = tf.random.normal((3, 10))
ref_output = model(ref_input_x, ref_input_y)

export_lib.export_model(model, temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(
ref_output, revived_model.serve(ref_input_x, ref_input_y)
)
# Test with a different batch size
revived_model.serve(
tf.random.normal((6, 10)), tf.random.normal((6, 10))
)

@parameterized.named_parameters(
named_product(model_type=["sequential", "functional", "subclass"])
)
Expand Down

1 comment on commit 20bc267

@Pawandeepnatt
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

****

Please sign in to comment.