Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ggml/src/ggml-openvino/ggml-openvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1017,11 +1017,20 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
break;
}
case GGML_OP_GATED_DELTA_NET: {
if (ggml_openvino_get_device_name() == "GPU" && op->src[0]->ne[2] > 1) {
// CVS-186471
// if (ggml_openvino_get_device_name() == "GPU" && op->src[0]->ne[2] > 1) {
// // CVS-186471
// return true;
// }
if (op->src[0]->op == GGML_OP_PERMUTE) {
return true;
}
if (op->src[0]->op == GGML_OP_PERMUTE) {
// kda (per-key-dimension gating) not supported by fused GatedDeltaNet op
if (op->src[3]->ne[0] != 1) {
return true;
}
// v_repeat > 1 (GQA): ggml uses modulo head mapping (h_q = h_v % H_k)
// but the fused op uses consecutive mapping (h_q = h_v / group_size)
if (op->src[2]->ne[1] != op->src[0]->ne[1]) {
return true;
}
break;
Expand All @@ -1033,7 +1042,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
}

static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
// return true;
GGML_ASSERT(dev->reg != nullptr);

static std::set<ggml_type> supported_types{GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_I64,
Expand Down
57 changes: 57 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/gated_delta_net.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "gated_delta_net.hpp"

#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"
Expand Down Expand Up @@ -27,6 +29,61 @@ namespace ggml {
namespace op {

// Translates a gated-delta-net node into the fused ov::op::internal::GatedDeltaNet op.
//
// Inputs, in order: q (0), k (1), v (2), g / gate (3), beta (4), recurrent state (5).
//
// Returns one output: the attention result and the updated recurrent state, each flattened to
// 1-D, concatenated (attention first) and reshaped to [1, 1, T * B + S_v * B, S_v * H_v].
//
// Only the scalar-gate form is handled: g and beta are squeezed along axis 3, so that dimension
// must be 1 for both.
// NOTE(review): per-key-dimension gating (kda) is presumably rejected before translation by the
// backend's supports-op check - confirm against the caller.
OutputVector translate_gated_delta_net(const NodeContext & context) {
    // Same arity contract as translate_gated_delta_net_ref.
    num_inputs_check(context, 6, 6);

    const auto v_shape = context.get_input_shape(2).to_shape();  // [B, T, H_v, S_v]
    const auto q_shape = context.get_input_shape(0).to_shape();  // [B, T, H_k, S_k]

    const int64_t B   = v_shape[0];
    const int64_t T   = v_shape[1];
    const int64_t H_v = v_shape[2];
    const int64_t S_v = v_shape[3];
    const int64_t S_k = q_shape[3];

    auto q     = context.get_input(0);
    auto k     = context.get_input(1);
    auto v     = context.get_input(2);
    auto g     = context.get_input(3);
    auto beta  = context.get_input(4);
    auto state = context.get_input(5);

    // ggml state layout (OV notation): [B, H_v, value_dim, key_dim]
    // GatedDeltaNet op expects:        [B, H_v, key_dim, value_dim]
    // Give the incoming state an explicit 4-D shape, then swap its last two axes.
    auto state_reshape_shape =
        ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector<int64_t>{B, H_v, S_v, S_k});
    state = std::make_shared<ov::op::v1::Reshape>(state, state_reshape_shape, false);
    // The same permutation is its own inverse, so it is reused for the output state below.
    auto state_perm = ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector<int64_t>{0, 1, 3, 2});
    state = std::make_shared<ov::op::v1::Transpose>(state, state_perm);

    // Drop the trailing axis of the gate and beta tensors before handing them to the fused op.
    auto squeeze_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{3});
    g    = std::make_shared<ov::op::v0::Squeeze>(g, squeeze_axis);
    beta = std::make_shared<ov::op::v0::Squeeze>(beta, squeeze_axis);

    auto gdn = std::make_shared<ov::op::internal::GatedDeltaNet>(q, k, v, state, g, beta);

    auto attn_4d  = gdn->output(0);
    auto state_4d = gdn->output(1);  // [B, H_v, key_dim, value_dim]

    // Transpose output state back to ggml layout [B, H_v, value_dim, key_dim]
    auto state_transposed = std::make_shared<ov::op::v1::Transpose>(state_4d, state_perm);

    // Flatten both results and pack them into a single tensor: attention first, then state.
    auto flat_shape_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{-1});
    auto attn          = std::make_shared<ov::op::v1::Reshape>(attn_4d, flat_shape_1d, false);
    auto new_state     = std::make_shared<ov::op::v1::Reshape>(state_transposed, flat_shape_1d, false);
    auto packed        = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{attn, new_state}, 0);

    // NOTE(review): the state contributes B * H_v * S_v * S_k elements (see the reshape above),
    // while this shape reserves S_v * B rows of S_v * H_v for it; the two agree only when
    // S_k == S_v - confirm that always holds for this op.
    auto out_shape =
        ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector<int64_t>{1, 1, T * B + S_v * B, S_v * H_v});
    auto res = std::make_shared<ov::op::v1::Reshape>(packed, out_shape, false);

    return rename_outputs_with_suffix({res}, context.get_name());
}

OutputVector translate_gated_delta_net_ref(const NodeContext & context) {
num_inputs_check(context, 6, 6);

// Inputs (OV shapes are reversed from ggml):
Expand Down
65 changes: 65 additions & 0 deletions ggml/src/ggml-openvino/openvino/op/gated_delta_net.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#pragma once

#include "openvino/op/op.hpp"

namespace ov::op::internal {
/// \note GatedDeltaNet op class is under development and subject to change
///
/// \brief Operator performing Gated Delta Net computation
/// \ingroup ov_ops_cpp_api
// Declaration-only view of the GatedDeltaNet operation: the non-default constructors and the
// three overrides below are declared here but defined elsewhere (not in this header).
// NOTE(review): the class is tagged OPENVINO_API, so the definitions are presumably provided by
// the OpenVINO library itself and this declaration must stay in sync with it - confirm.
class OPENVINO_API GatedDeltaNet : public ov::op::Op {
public:
OPENVINO_OP("GatedDeltaNet")

/// \brief Default constructor; attributes keep their in-class default values.
GatedDeltaNet() = default;
/// \brief Constructs a GatedDeltaNet operation.
///
/// \param query Query tensor input.
/// \param key Key tensor input.
/// \param value Value tensor input.
/// \param recurrent_state Initial recurrent state tensor.
/// \param gate Gate tensor controlling state decay/update.
/// \param beta Beta tensor scaling the delta update.
/// \param fuse_qk_l2norm Enables fusing q/k L2-normalization into this op.
/// \param q_l2_norm_eps Epsilon used for query L2-normalization when fusion is enabled.
/// \param k_l2_norm_eps Epsilon used for key L2-normalization when fusion is enabled.
GatedDeltaNet(const Output<Node>& query,
const Output<Node>& key,
const Output<Node>& value,
const Output<Node>& recurrent_state,
const Output<Node>& gate,
const Output<Node>& beta,
const bool fuse_qk_l2norm = false,
const float q_l2_norm_eps = 1e-6F,
const float k_l2_norm_eps = 1e-6F);

/// \brief Constructs a GatedDeltaNet operation from input vector.
///
/// \param args Input tensor vector in order: query, key, value, recurrent_state, gate, beta.
/// \param fuse_qk_l2norm Enables fusing q/k L2-normalization into this op.
/// \param q_l2_norm_eps Epsilon used for query L2-normalization when fusion is enabled.
/// \param k_l2_norm_eps Epsilon used for key L2-normalization when fusion is enabled.
GatedDeltaNet(const ov::OutputVector& args,
const bool fuse_qk_l2norm = false,
const float q_l2_norm_eps = 1e-6F,
const float k_l2_norm_eps = 1e-6F);
/// \brief Override of the base-class validation / type-inference hook (defined elsewhere).
void validate_and_infer_types() override;
/// \brief Override of the base-class attribute-visiting hook (defined elsewhere).
/// NOTE(review): presumably exposes the three attributes stored below - confirm in the definition.
bool visit_attributes(AttributeVisitor& visitor) override;
/// \brief Override of the base-class clone hook (defined elsewhere).
///
/// \param new_args Inputs for the cloned node.
/// \return Shared pointer to the new node.
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
/// \return The stored fuse_qk_l2norm flag.
bool get_fuse_qk_l2norm() const {
return m_fuse_qk_l2norm;
}
/// \return The stored epsilon for query L2-normalization.
float get_q_l2_norm_eps() const {
return m_q_l2_norm_eps;
}
/// \return The stored epsilon for key L2-normalization.
float get_k_l2_norm_eps() const {
return m_k_l2_norm_eps;
}

private:
// Attribute storage; the in-class initializers match the constructor default arguments.
bool m_fuse_qk_l2norm = false;
float m_q_l2_norm_eps = 1e-6F;
float m_k_l2_norm_eps = 1e-6F;
};

} // namespace ov::op::internal
Loading