Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,20 +237,17 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {
const int mode = node->op_params[2];
switch (mode) {
case GGML_ROPE_TYPE_NEOX: {
op_case = 0x00010000;
op_case = 1;
break;
}
case GGML_ROPE_TYPE_IMROPE: {
op_case = 0x00020000;
op_case = 2;
break;
}
default:
op_case = 0x00000000;
op_case = 0;
break;
}
if (node->src[0]->op == GGML_OP_VIEW) {
op_case = (op_case | 0x00000002);
}
break;
}
case GGML_OP_VIEW: {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/ggml-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
}

inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) {
return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE;
return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE && op->src[1]->op == GGML_OP_NONE;
}

std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) {
Expand Down
11 changes: 4 additions & 7 deletions ggml/src/ggml-openvino/openvino/op/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@ OutputVector translate_rope(const NodeContext & context) {

ov::Output<Node> res;

auto data_node = process_view_input_new(context, 0).get_node_shared_ptr();
auto data_node = context.get_input(0).get_node_shared_ptr();
auto output_shape = context.get_output_shape().to_shape();
int32_t * op_params = context.get_output_op_params();
const int mode = (op_case & 0xFFFF0000) >> 16;
op_case = (op_case & 0x0000FFFF);
const int mode = op_case;

constexpr int TYPE_NORMAL = 0;
constexpr int TYPE_NEOX = 1;
Expand All @@ -61,10 +60,8 @@ OutputVector translate_rope(const NodeContext & context) {
cos_theta_node = sin_cos.second;
}

if (op_case == 2) {
// The input comes from a VIEW
int slice_len = output_shape[2] * output_shape[3];
data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();
if (context.get_view_input_size(0) > 0) {
data_node = process_view_input_new(context, 0).get_node_shared_ptr();
if (context.is_stateful()) {
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
Expand Down
Loading