diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 732950ca6869..40cacd276c7a 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -238,6 +238,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): # "UNIDIRECTIONAL_SEQUENCE_LSTM": self.convert_unidirectional_sequence_lstm, "WHERE": self.convert_select, "ZEROS_LIKE": self.convert_zeros_like, + "NON_MAX_SUPPRESSION_V4": self.convert_nms_v4, "NON_MAX_SUPPRESSION_V5": self.convert_nms_v5, } @@ -3589,6 +3590,64 @@ def convert_detection_postprocess(self, op): num_detections = relax.op.astype(num_detections, "float32") return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections]) + def convert_nms_v4(self, op): + """Convert TFLite NonMaxSuppressionV4""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 5, "input tensor length should be 5" + + boxes = self.get_tensor_expr(input_tensors[0]) + scores = self.get_tensor_expr(input_tensors[1]) + + max_output_size = self.get_tensor_value(input_tensors[2]) + iou_threshold = self.get_tensor_value(input_tensors[3]) + score_threshold = self.get_tensor_value(input_tensors[4]) + + if isinstance(max_output_size, np.ndarray): + assert max_output_size.size == 1, "only one value is expected." + max_output_size = int(max_output_size) + + if isinstance(iou_threshold, np.ndarray): + assert iou_threshold.size == 1, "only one value is expected." + iou_threshold = float(iou_threshold) + + if isinstance(score_threshold, np.ndarray): + assert score_threshold.size == 1, "only one value is expected." + score_threshold = float(score_threshold) + + scores_expand = relax.op.expand_dims(scores, axis=-1) + data = relax.op.concat([scores_expand, boxes], axis=-1) + data = relax.op.expand_dims(data, axis=0) + + valid_counts_ret = relax.op.vision.get_valid_counts( + data, score_threshold=score_threshold, id_index=-1, score_index=0 + ) + count = valid_counts_ret[0] + data = valid_counts_ret[1] + indices = valid_counts_ret[2] + + nms_ret = relax.op.vision.non_max_suppression( + data=data, + valid_count=count, + indices=indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=-1, + coord_start=1, + score_index=0, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + ) + + selected_indices = relax.op.squeeze(nms_ret[0], axis=[0]) + selected_indices = relax.op.strided_slice( + selected_indices, axes=[0], begin=[0], end=[max_output_size] + ) + num_valid = relax.op.reshape(nms_ret[1], []) + + return relax.Tuple([selected_indices, num_valid]) + def convert_nms_v5(self, op): """Convert TFLite NonMaxSuppressionV5""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 908868faf0c4..dde62fde2868 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -1360,6 +1360,59 @@ def main( verify(BatchMatMulAdj, Expected) +def _verify_nms_v4(mod, tf_func, boxes_np, scores_np): + """E2E verify for NMS V4: only run on nightly, compare valid outputs only.""" + if "CI_ENV_NIGHTLY" not in os.environ: + return + + tf_indices, tf_valid = tf_func(tf.constant(boxes_np), tf.constant(scores_np)) + n_valid = int(tf_valid.numpy()) + + tgt = tvm.target.Target("llvm") + ex = tvm.compile(mod, tgt) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", boxes_np, scores_np) + vm.invoke_stateful("main") + tvm_indices, tvm_valid = vm.get_outputs("main") + + assert int(tvm_valid.numpy()) == n_valid + np.testing.assert_array_equal( + tf_indices.numpy()[:n_valid], + tvm_indices.numpy()[:n_valid], + ) + + +def _build_nms_v4_mod(num_boxes, max_output_size, iou_threshold, score_threshold): + """Convert a NonMaxSuppressionV4 TFLite model to a Relax module. + + Scalar params must be Python literals (not tf.constant) so TFLite can + statically infer output shapes during conversion. + """ + + class NMSv4Module(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(num_boxes, 4), dtype=tf.float32), + tf.TensorSpec(shape=(num_boxes,), dtype=tf.float32), + ] + ) + def func(self, boxes, scores): + indices, valid = tf.raw_ops.NonMaxSuppressionV4( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True, + ) + return indices, valid + + instance = NMSv4Module() + cf = instance.func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + return mod, instance.func + + def _verify_nms_v5(mod, tf_func, boxes_np, scores_np, soft_nms_sigma=0.0): """E2E verify for NMS: only run on nightly, compare valid outputs only.""" if "CI_ENV_NIGHTLY" not in os.environ: @@ -1804,6 +1857,100 @@ def test_nms_v5_soft_ir(): assert "R.clip(" in ir +_NMS_V4_CASES = [ + pytest.param( + 6, + 3, + 0.5, + 0.0, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, 0.0, 1.0, 0.9], + [0.5, 0.5, 1.5, 1.5], + [0.0, 0.0, 0.3, 0.3], + ], + dtype=np.float32, + ), + np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), + id="basic", + ), + pytest.param( + 8, + 4, + 0.5, + 0.4, + _make_valid_boxes(np.random.default_rng(42), 8), + np.random.default_rng(42).random(8, dtype=np.float32), + id="score_threshold", + ), + pytest.param( + 5, + 3, + 0.5, + 0.99, + _make_valid_boxes(np.random.default_rng(0), 5), + np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32), + id="all_suppressed", + ), + pytest.param( + 4, + 10, + 0.5, + 0.0, + np.array( + [ + [0.0, 0.0, 0.3, 0.3], + [0.5, 0.5, 0.8, 0.8], + [0.1, 0.1, 0.4, 0.4], + [0.6, 0.6, 0.9, 0.9], + ], + dtype=np.float32, + ), + np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32), + id="max_output_size_larger_than_boxes", + ), +] + + +@pytest.mark.parametrize( + "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores", + _NMS_V4_CASES, +) +def test_nms_v4(num_boxes, max_output_size, iou_threshold, score_threshold, boxes, scores): + """NON_MAX_SUPPRESSION_V4: conversion smoke test + E2E correctness (nightly only).""" + mod, tf_func = _build_nms_v4_mod(num_boxes, max_output_size, iou_threshold, score_threshold) + _verify_nms_v4(mod, tf_func, boxes, scores) + + +def test_nms_v4_ir(): + """Verify the emitted Relax IR has correct structure for NON_MAX_SUPPRESSION_V4.""" + num_boxes = 6 + max_output_size = 3 + mod, _ = _build_nms_v4_mod( + num_boxes=num_boxes, + max_output_size=max_output_size, + iou_threshold=0.5, + score_threshold=0.0, + ) + + ir = mod.script() + + # Validate correct sorting/id indices are passed to valid_counts + assert "score_index=0" in ir + assert "id_index=-1" in ir + # NMS size limit validation + assert f"max_output_size={max_output_size}" in ir + # Valid output shape must be () statically + assert 'R.Tensor((), dtype="int32")' in ir + # Selected indices tensor bounds check + assert f"R.Tensor(({max_output_size},)" in ir + # V4 must use hard-NMS (soft_nms_sigma left at default 0.0) + assert "soft_nms_sigma=0.0" in ir + + _DETECTION_POSTPROCESS_SMOKE_CASES = [ pytest.param( {