Skip to content

Commit

Permalink
Add tests (#19729)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed May 18, 2024
1 parent 20bc267 commit a05ac12
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras/src/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def get(identifier):

if identifier is None:
return dtype_policy.dtype_policy()
if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)):
if isinstance(
identifier, (DTypePolicy, FloatDTypePolicy, QuantizedDTypePolicy)
):
return identifier
if isinstance(identifier, dict):
return deserialize(identifier)
Expand Down
12 changes: 12 additions & 0 deletions keras/src/dtype_policies/dtype_policy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,13 +570,25 @@ def test_get_valid_policy(self):
policy = get("mixed_float16")
self.assertEqual(policy.name, "mixed_float16")

policy = get(DTypePolicy("bfloat16"))
self.assertEqual(policy.name, "bfloat16")

policy = get(FloatDTypePolicy("mixed_float16"))
self.assertEqual(policy.name, "mixed_float16")

def test_get_valid_policy_quantized(self):
policy = get("int8_from_mixed_bfloat16")
self.assertEqual(policy.name, "int8_from_mixed_bfloat16")

policy = get("float8_from_float32")
self.assertEqual(policy.name, "float8_from_float32")

policy = get(QuantizedDTypePolicy("int8", "mixed_bfloat16"))
self.assertEqual(policy.name, "int8_from_mixed_bfloat16")

policy = get(QuantizedFloat8DTypePolicy("float8", "mixed_float16"))
self.assertEqual(policy.name, "float8_from_mixed_float16")

def test_get_invalid_policy(self):
with self.assertRaisesRegex(ValueError, "Cannot convert"):
get("mixed_bfloat15")
Expand Down
20 changes: 20 additions & 0 deletions keras/src/ops/operation_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from keras.src import backend
from keras.src import dtype_policies
from keras.src import testing
from keras.src.backend.common import keras_tensor
from keras.src.ops import numpy as knp
Expand Down Expand Up @@ -43,6 +44,17 @@ def compute_output_spec(self, x):
return keras_tensor.KerasTensor(x.shape, x.dtype)


class OpWithCustomDtype(operation.Operation):
def __init__(self, dtype):
super().__init__(dtype=dtype)

def call(self, x):
return x

def compute_output_spec(self, x):
return keras_tensor.KerasTensor(x.shape, x.dtype)


class OperationTest(testing.TestCase):
def test_symbolic_call(self):
x = keras_tensor.KerasTensor(shape=(2, 3), name="x")
Expand Down Expand Up @@ -160,3 +172,11 @@ def test_valid_naming(self):
ValueError, "must be a string and cannot contain character `/`."
):
OpWithMultipleOutputs(name="test/op")

def test_dtype(self):
op = OpWithCustomDtype(dtype="bfloat16")
self.assertEqual(op._dtype_policy.name, "bfloat16")

policy = dtype_policies.DTypePolicy("mixed_bfloat16")
op = OpWithCustomDtype(dtype=policy)
self.assertEqual(op._dtype_policy.name, "mixed_bfloat16")

0 comments on commit a05ac12

Please sign in to comment.