Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion ggml/src/ggml-openvino/openvino/op/div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
#include "../op_table.h"
#include "../utils.h"

#include "ggml.h"

#include <memory>
#include <openvino/op/util/precision_sensitive_attribute.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/divide.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/sigmoid.hpp>
#include <openvino/op/tile.hpp>
#include <vector>

Expand All @@ -17,6 +22,36 @@ namespace op {

namespace {

bool is_silu_div_pattern(const ov::Output<ov::Node> & numerator,
const ov::Output<ov::Node> & denominator,
const NodeContext & context) {
if (context.get_input_size() != 2) {
return false;
}

const auto * unary_op = reinterpret_cast<const ggml_unary_op *>(context.get_input_op_params(0));
if (unary_op == nullptr || *unary_op != GGML_UNARY_OP_SILU) {
return false;
}

auto mul = std::dynamic_pointer_cast<ov::op::v1::Multiply>(numerator.get_node_shared_ptr());
if (!mul) {
return false;
}

const auto denom_node = denominator.get_node_shared_ptr();
const auto mul_input_0 = mul->input_value(0).get_node_shared_ptr();
const auto mul_input_1 = mul->input_value(1).get_node_shared_ptr();

auto sigmoid = std::dynamic_pointer_cast<ov::op::v0::Sigmoid>(mul_input_1);
if (mul_input_0 == denom_node && sigmoid && sigmoid->input_value(0).get_node_shared_ptr() == denom_node) {
return true;
}

sigmoid = std::dynamic_pointer_cast<ov::op::v0::Sigmoid>(mul_input_0);
return mul_input_1 == denom_node && sigmoid && sigmoid->input_value(0).get_node_shared_ptr() == denom_node;
}

ov::Output<ov::Node> repeat_input_to_match(const NodeContext & context,
const ov::Output<ov::Node> & input,
const ov::Output<ov::Node> & target,
Expand Down Expand Up @@ -68,6 +103,15 @@ OutputVector translate_div(const NodeContext & context) {

auto input_0 = process_view_input_new(context, 0);
auto input_1 = process_view_input_new(context, 1);

if (is_silu_div_pattern(input_0, input_1, context)) {
ov::Output<ov::Node> res = std::make_shared<ov::op::v0::Sigmoid>(input_1);
if (res.get_element_type() != context.get_output_type()) {
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());
}
return rename_outputs_with_suffix({res}, context.get_name());
}

input_1 = repeat_input_to_match(context, input_1, input_0, 1);

const auto output_type = context.get_output_type();
Expand All @@ -81,8 +125,19 @@ OutputVector translate_div(const NodeContext & context) {
}

ov::Output<ov::Node> res = std::make_shared<ov::op::v1::Divide>(input_0, input_1);
if (use_f32_compute) {
// Keep the reciprocal/divide path in FP32. Without this hint, the GPU
// plugin can still compress the subgraph back to FP16 and overflow on
// small shexp gate values (e.g. silu(x) / x in qwen2moe).
ov::mark_as_precision_sensitive(res.get_node_shared_ptr()->input(0));
ov::mark_as_precision_sensitive(res.get_node_shared_ptr()->input(1));
}
if (res.get_element_type() != output_type) {
res = std::make_shared<ov::op::v0::Convert>(res, output_type);
auto output_convert = std::make_shared<ov::op::v0::Convert>(res, output_type);
if (use_f32_compute) {
ov::mark_as_precision_sensitive(output_convert->input(0));
}
res = output_convert;
}
return rename_outputs_with_suffix({res}, context.get_name());
}
Expand Down
Loading