diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 72fc47fc81a..49c21f1b054 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -237,20 +237,17 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { const int mode = node->op_params[2]; switch (mode) { case GGML_ROPE_TYPE_NEOX: { - op_case = 0x00010000; + op_case = 1; break; } case GGML_ROPE_TYPE_IMROPE: { - op_case = 0x00020000; + op_case = 2; break; } default: - op_case = 0x00000000; + op_case = 0; break; } - if (node->src[0]->op == GGML_OP_VIEW) { - op_case = (op_case | 0x00000002); - } break; } case GGML_OP_VIEW: { diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 7bde5a2fd0c..91850a000b5 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -281,7 +281,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { } inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) { - return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE; + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE && op->src[1]->op == GGML_OP_NONE; } std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) { diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 263d733bd4a..de8bcdb38de 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -35,11 +35,10 @@ OutputVector translate_rope(const NodeContext & context) { ov::Output res; - auto data_node = process_view_input_new(context, 0).get_node_shared_ptr(); + auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); - const int mode = (op_case & 0xFFFF0000) >> 16; - op_case = (op_case & 0x0000FFFF); + const int mode = op_case; constexpr int TYPE_NORMAL = 0; constexpr int TYPE_NEOX = 1; @@ -61,10 +60,8 @@ OutputVector translate_rope(const NodeContext & context) { cos_theta_node = sin_cos.second; } - if (op_case == 2) { - // The input comes from a VIEW - int slice_len = output_shape[2] * output_shape[3]; - data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr(); + if (context.get_view_input_size(0) > 0) { + data_node = process_view_input_new(context, 0).get_node_shared_ptr(); if (context.is_stateful()) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {3}, std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});