From 86f8daacfe5e202b99ed396aa574a8b41a982048 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 20 Apr 2026 23:29:19 +0200 Subject: [PATCH] mtmd: correct get_n_pos / get_decoder_pos (#22175) --- tools/mtmd/mtmd.cpp | 82 ++++++++++++++++++++++++++++++++------------- tools/mtmd/mtmd.h | 10 +++--- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 35b4bba77..dfc29e909 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -33,10 +33,16 @@ struct mtmd_bitmap { bool is_audio = false; // true if the bitmap is audio }; +// position indexing for decoder model +enum mtmd_pos_type { + MTMD_POS_TYPE_NORMAL, // number of positions equals to number of tokens + MTMD_POS_TYPE_MROPE, // qwen-vl mrope style, each image takes max(t,h,w) position indexes +}; + struct mtmd_image_tokens { uint32_t nx; // number of tokens in x direction uint32_t ny; // number of tokens in y direction - bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position) + mtmd_pos_type pos = MTMD_POS_TYPE_NORMAL; uint32_t n_tokens() const { return nx * ny; } clip_image_f32_batch batch_f32; // preprocessed image patches std::string id; // optional user-defined ID, useful for KV cache tracking @@ -45,7 +51,7 @@ struct mtmd_image_tokens { return mtmd_image_tokens{ nx, ny, - use_mrope_pos, + pos, batch_f32.clone(), id }; @@ -131,7 +137,7 @@ struct mtmd_context { int n_threads; std::string media_marker; const int n_embd_text; - llama_rope_type decoder_rope; + mtmd_pos_type pos_type; // these are not token, but strings used to mark the beginning and end of image/audio embeddings std::string img_beg; @@ -168,8 +174,7 @@ struct mtmd_context { print_timings(ctx_params.print_timings), n_threads (ctx_params.n_threads), media_marker (ctx_params.media_marker), - n_embd_text (llama_model_n_embd_inp(text_model)), - decoder_rope (llama_model_rope_type(text_model)) + n_embd_text (llama_model_n_embd_inp(text_model)) { if (ctx_params.image_marker != nullptr) { throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead"); @@ -179,6 +184,22 @@ struct mtmd_context { throw std::runtime_error("media_marker must not be empty"); } + auto decoder_rope_type = llama_model_rope_type(text_model); + switch (decoder_rope_type) { + case LLAMA_ROPE_TYPE_NORM: + case LLAMA_ROPE_TYPE_NEOX: + { + pos_type = MTMD_POS_TYPE_NORMAL; + } break; + case LLAMA_ROPE_TYPE_MROPE: + case LLAMA_ROPE_TYPE_IMROPE: + { + pos_type = MTMD_POS_TYPE_MROPE; + } break; + default: + throw std::runtime_error(string_format("unsupported decoder rope type: %d\n", decoder_rope_type)); + } + clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type), @@ -779,12 +800,12 @@ struct mtmd_tokenizer { // for Qwen2VL, we need this information for M-RoPE decoding positions image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get()); image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get()); - image_tokens->use_mrope_pos = true; } else { // other models, we only need the total number of tokens image_tokens->nx = n_tokens; image_tokens->ny = 1; } + image_tokens->pos = ctx->pos_type; image_tokens->batch_f32 = std::move(batch_f32); image_tokens->id = bitmap->id; // optional @@ -1016,7 +1037,7 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { return ctx->image_embd_v.data(); } -bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk) { +bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk) { auto proj_type = ctx->proj_type_v(); if (chunk && chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { proj_type = ctx->proj_type_a(); @@ -1030,20 +1051,19 @@ bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chu } } -bool mtmd_decode_use_mrope(mtmd_context * ctx) { - return ctx->decoder_rope == LLAMA_ROPE_TYPE_MROPE - || ctx->decoder_rope == LLAMA_ROPE_TYPE_IMROPE; +bool mtmd_decode_use_mrope(const mtmd_context * ctx) { + return ctx->pos_type; } -bool mtmd_support_vision(mtmd_context * ctx) { +bool mtmd_support_vision(const mtmd_context * ctx) { return ctx->ctx_v != nullptr; } -bool mtmd_support_audio(mtmd_context * ctx) { +bool mtmd_support_audio(const mtmd_context * ctx) { return ctx->ctx_a != nullptr; } -int mtmd_get_audio_sample_rate(mtmd_context * ctx) { +int mtmd_get_audio_sample_rate(const mtmd_context * ctx) { if (!ctx->ctx_a) { return -1; } @@ -1238,12 +1258,24 @@ size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) { mtmd_decoder_pos mtmd_image_tokens_get_decoder_pos(const mtmd_image_tokens * image_tokens, llama_pos pos_0, size_t i) { mtmd_decoder_pos pos; - // M-RoPE logic - // TODO: support other types of position encoding if needed - pos.t = pos_0; - pos.x = pos_0 + (i % image_tokens->nx); - pos.y = pos_0 + (i / image_tokens->nx); - pos.z = 0; // unused for now + switch (image_tokens->pos) { + case MTMD_POS_TYPE_MROPE: + { + pos.t = pos_0; + pos.x = pos_0 + (i % image_tokens->nx); + pos.y = pos_0 + (i / image_tokens->nx); + pos.z = 0; // unused for now + } break; + case MTMD_POS_TYPE_NORMAL: + { + pos.t = pos_0 + i; + pos.x = pos_0 + i; + pos.y = pos_0 + i; + pos.z = pos_0 + i; + } break; + default: + GGML_ABORT("invalid position type"); + } return pos; } @@ -1252,12 +1284,14 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) { } llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) { - if (image_tokens->use_mrope_pos) { - // for M-RoPE, temporal dimension = max(t,h,w) - // t is omitted as we don't support video input - return std::max(image_tokens->nx, image_tokens->ny); + switch (image_tokens->pos) { + case MTMD_POS_TYPE_MROPE: + return std::max(image_tokens->nx, image_tokens->ny); + case MTMD_POS_TYPE_NORMAL: + return image_tokens->n_tokens(); + default: + GGML_ABORT("invalid position type"); } - return image_tokens->n_tokens(); } // test function diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 6e36cb8ec..e364174b8 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -112,20 +112,20 @@ MTMD_API void mtmd_free(mtmd_context * ctx); // whether we need to set non-causal mask before llama_decode // if chunk is nullptr, we assume the default case where chunk is an image chunk -MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx, const mtmd_input_chunk * chunk); +MTMD_API bool mtmd_decode_use_non_causal(const mtmd_context * ctx, const mtmd_input_chunk * chunk); // whether the current model use M-RoPE for llama_decode -MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx); +MTMD_API bool mtmd_decode_use_mrope(const mtmd_context * ctx); // whether the current model supports vision input -MTMD_API bool mtmd_support_vision(mtmd_context * ctx); +MTMD_API bool mtmd_support_vision(const mtmd_context * ctx); // whether the current model supports audio input -MTMD_API bool mtmd_support_audio(mtmd_context * ctx); +MTMD_API bool mtmd_support_audio(const mtmd_context * ctx); // get audio sample rate in Hz, for example 16000 for Whisper // return -1 if audio is not supported -MTMD_API int mtmd_get_audio_sample_rate(mtmd_context * ctx); +MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx); // mtmd_bitmap //