diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 732950ca6869..f60621fdda72 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -124,6 +124,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"), "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd, "BATCH_MATMUL": self.convert_batch_matmul, + "BITCAST": self.convert_bitcast, "CAST": self.convert_cast, "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil), "CONCATENATION": self.convert_concatenation, @@ -2441,6 +2442,28 @@ def convert_reverse_sequence(self, op): return relax.op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis) + def convert_bitcast(self, op): + """Convert TFLite BITCAST""" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + assert len(output_tensors) == 1, "output tensors length should be 1" + + in_expr = self.get_tensor_expr(input_tensors[0]) + input_dtype = self.get_tensor_type_str(input_tensors[0].tensor.Type()) + output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type()) + input_shape = to_int_list(self.get_tensor_shape(input_tensors[0])) + output_shape = to_int_list(self.get_tensor_shape(output_tensors[0])) + + input_nbytes = int(np.prod(input_shape)) * np.dtype(input_dtype).itemsize + output_nbytes = int(np.prod(output_shape)) * np.dtype(output_dtype).itemsize + assert input_nbytes == output_nbytes, ( + "TFLite BITCAST requires input.nbytes == output.nbytes, " + f"but got input={input_nbytes} bytes, output={output_nbytes} bytes" + ) + + return relax.op.memory.view(in_expr, shape=output_shape, dtype=output_dtype) + def convert_cast(self, op): """Convert TFLite CAST""" diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 908868faf0c4..6786e4014cf9 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -256,6 +256,94 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="int3 verify(Cast, Expected) +def test_bitcast_float32_to_int32(): + """BITCAST same-width: float32 -> int32, shape preserved.""" + + class BitcastF32ToI32(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return tf.bitcast(x, tf.int32) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="int32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((1, 30), dtype="int32") = R.memory.view( + x, R.shape([1, 30]), R.dtype("int32") + ) + R.output(gv) + return gv + + verify(BitcastF32ToI32, Expected) + + +def test_bitcast_uint8_to_int8(): + """BITCAST same-width 8-bit: uint8 -> int8.""" + + class BitcastU8ToI8(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(4,), dtype=tf.uint8)]) + def func(self, x): + return tf.bitcast(x, tf.int8) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4,), dtype="uint8")) -> R.Tensor((4,), dtype="int8"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((4,), dtype="int8") = R.memory.view(x, R.shape([4]), R.dtype("int8")) + R.output(gv) + return gv + + verify(BitcastU8ToI8, Expected) + + +def test_bitcast_int32_to_int16_widens_shape(): + """BITCAST width-changing (smaller): int32[3] -> int16[3, 2].""" + + class BitcastI32ToI16(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.int32)]) + def func(self, x): + return tf.bitcast(x, tf.int16) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2), dtype="int16"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((3, 2), dtype="int16") = R.memory.view( + x, R.shape([3, 2]), R.dtype("int16") + ) + R.output(gv) + return gv + + verify(BitcastI32ToI16, Expected) + + +def test_bitcast_int16_to_int32_collapses_shape(): + """BITCAST width-changing (larger): int16[5, 2] -> int32[5].""" + + class BitcastI16ToI32(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(5, 2), dtype=tf.int16)]) + def func(self, x): + return tf.bitcast(x, tf.int32) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((5, 2), dtype="int16")) -> R.Tensor((5,), dtype="int32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((5,), dtype="int32") = R.memory.view(x, R.shape([5]), R.dtype("int32")) + R.output(gv) + return gv + + verify(BitcastI16ToI32, Expected) + + def test_expand_dims(): class ExpandDims(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])