Skip to content

Commit

Permalink
Tweaks for serialization of DTypePolicy in ops/layers. (#19728)
Browse files Browse the repository at this point in the history
- Subclasses of `Operation` / `Layer` which override `__init__` and use the `dtype` parameter don't expect a `dict`. We deserialize the `DTypePolicy` in `from_config`.
- The auto `get_config` feature would break when a `DTypePolicy` was passed to the constructor of any `Operation` or `Layer` subclass not implementing `get_config`.
  • Loading branch information
hertschuh committed May 17, 2024
1 parent 097673f commit 6e40533
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
4 changes: 2 additions & 2 deletions keras/src/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get(identifier):
The `identifier` may be the string name of a `DTypePolicy` class.
>>> policy = dtype_policies.get("mixed_bfloat16")
>>> type(loss)
>>> type(policy)
<class '...FloatDTypePolicy'>
You can also specify `config` of the dtype policy to this function by
Expand All @@ -70,7 +70,7 @@ def get(identifier):
>>> identifier = {"class_name": "FloatDTypePolicy",
... "config": {"name": "float32"}}
>>> policy = dtype_policies.get(identifier)
>>> type(loss)
>>> type(policy)
<class '...FloatDTypePolicy'>
Args:
Expand Down
23 changes: 14 additions & 9 deletions keras/src/ops/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,16 @@ def __new__(cls, *args, **kwargs):
out of the box in most cases without forcing the user
to manually implement `get_config()`.
"""
instance = super(Operation, cls).__new__(cls)

# Generate a config to be returned by default by `get_config()`.
arg_names = inspect.getfullargspec(cls.__init__).args
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
instance = super(Operation, cls).__new__(cls)
if "dtype" in kwargs and isinstance(
kwargs["dtype"], dtype_policies.DTypePolicy
):
kwargs["dtype"] = kwargs["dtype"].get_config()

# For safety, we only rely on auto-configs for a small set of
# serializable types.
supported_types = (str, int, float, bool, type(None))
Expand Down Expand Up @@ -187,20 +193,19 @@ def get_config(self):

@classmethod
def from_config(cls, config):
"""Creates a layer from its config.
"""Creates an operation from its config.
This method is the reverse of `get_config`,
capable of instantiating the same layer from the config
dictionary. It does not handle layer connectivity
(handled by Network), nor weights (handled by `set_weights`).
This method is the reverse of `get_config`, capable of instantiating the
same operation from the config dictionary.
Args:
config: A Python dictionary, typically the
output of get_config.
config: A Python dictionary, typically the output of `get_config`.
Returns:
A layer instance.
An operation instance.
"""
if "dtype" in config and isinstance(config["dtype"], dict):
config["dtype"] = dtype_policies.deserialize(config["dtype"])
try:
return cls(**config)
except Exception as e:
Expand Down

0 comments on commit 6e40533

Please sign in to comment.