github-actions[bot] commented on code in PR #62147:
URL: https://github.com/apache/doris/pull/62147#discussion_r3077190946


##########
be/src/exprs/function/ai/ai_adapter.h:
##########
@@ -969,6 +1207,51 @@ class GeminiAdapter : public AIAdapter {
         return Status::OK();
     }
 
+    Status build_multimodal_embedding_request(MultimodalType media_type,
+                                              const std::string& media_url,
+                                              std::string& request_body) const override {
+        const char* mime_type = nullptr;
+        switch (media_type) {
+        case MultimodalType::IMAGE:
+            mime_type = "image/png";
+            break;
+        case MultimodalType::AUDIO:
+            mime_type = "audio/mpeg";
+            break;
+        case MultimodalType::VIDEO:
+            mime_type = "video/mp4";
+            break;
+        }
+
+        rapidjson::Document doc;
+        doc.SetObject();
+        auto& allocator = doc.GetAllocator();
+
+        std::string model_name = _config.model_name;
+        if (!model_name.starts_with("models/")) {
+            model_name = "models/" + model_name;
+        }
+        doc.AddMember("model", rapidjson::Value(model_name.c_str(), allocator), allocator);

Review Comment:
   This hardcodes Gemini to one MIME per coarse media type (`image/png`, 
`audio/mpeg`, `video/mp4`) and ignores the actual `content_type` supplied in 
the FILE JSON. `_infer_media_type()` accepts any subtype with the same prefix, 
so valid inputs like JPEG/WebP/WAV/etc. will be sent downstream with the wrong 
MIME and can be rejected or misinterpreted. Please pass the original 
`content_type` through to this request builder (or reject unsupported subtypes 
before signing/sending the URL).



##########
be/src/exprs/function/ai/embed.h:
##########
@@ -33,7 +42,221 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
         return std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
     }
 
+    Status execute_with_adapter(FunctionContext* context, Block& block,
+                                const ColumnNumbers& arguments, uint32_t result,
+                                size_t input_rows_count, const TAIResource& config,
+                                std::shared_ptr<AIAdapter>& adapter) const {
+        if (arguments.size() != 2) {
+            return Status::InvalidArgument("Function EMBED expects 2 arguments, but got {}",
+                                           arguments.size());
+        }
+
+        PrimitiveType input_type =
+                remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
+        if (input_type == PrimitiveType::TYPE_JSONB) {
+            return _execute_multimodal_embed(context, block, arguments, result, input_rows_count,
+                                             config, adapter);
+        }
+        if (input_type == PrimitiveType::TYPE_STRING || input_type == PrimitiveType::TYPE_VARCHAR ||
+            input_type == PrimitiveType::TYPE_CHAR) {
+            return _execute_text_embed(context, block, arguments, result, input_rows_count, config,
+                                       adapter);
+        }
+        return Status::InvalidArgument(
+                "Function EMBED expects the second argument to be STRING or JSON, but got type {}",
+                block.get_by_position(arguments[1]).type->get_name());
+    }
+
     static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
+
+private:
+    Status _execute_text_embed(FunctionContext* context, Block& block,
+                               const ColumnNumbers& arguments, uint32_t result,
+                               size_t input_rows_count, const TAIResource& config,
+                               std::shared_ptr<AIAdapter>& adapter) const {
+        auto col_result = ColumnArray::create(
+                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            std::string prompt;
+            RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
+
+            std::vector<float> float_result;
+            RETURN_IF_ERROR(execute_single_request(prompt, float_result, config, adapter, context));
+            _insert_embedding_result(*col_result, float_result);
+        }
+
+        block.replace_by_position(result, std::move(col_result));
+        return Status::OK();
+    }
+
+    Status _execute_multimodal_embed(FunctionContext* context, Block& block,
+                                     const ColumnNumbers& arguments, uint32_t result,
+                                     size_t input_rows_count, const TAIResource& config,
+                                     std::shared_ptr<AIAdapter>& adapter) const {
+        auto col_result = ColumnArray::create(
+                ColumnNullable::create(ColumnFloat32::create(), ColumnUInt8::create()));
+
+        int64_t ttl_seconds = 3600;
+        QueryContext* query_ctx = context->state()->get_query_ctx();
+        if (query_ctx && query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
+            ttl_seconds = query_ctx->query_options().file_presigned_url_ttl_seconds;
+            if (ttl_seconds <= 0) {
+                ttl_seconds = 3600;
+            }
+        }
+
+        const ColumnWithTypeAndName& file_column = block.get_by_position(arguments[1]);
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            rapidjson::Document file_input;
+            RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
+
+            MultimodalType media_type;
+            RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
+
+            std::string media_url;
+            RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds, media_url));
+
+            std::string request_body;
+            RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type, media_url,
+                                                                        request_body));
+
+            std::vector<float> float_result;
+            RETURN_IF_ERROR(execute_embedding_request(request_body, float_result, config, adapter,
+                                                      context));
+            _insert_embedding_result(*col_result, float_result);
+        }
+
+        block.replace_by_position(result, std::move(col_result));
+        return Status::OK();
+    }
+
+    static void _insert_embedding_result(ColumnArray& col_array,
+                                         const std::vector<float>& float_result) {
+        auto& offsets = col_array.get_offsets();
+        auto& nested_nullable_col = assert_cast<ColumnNullable&>(col_array.get_data());
+        auto& nested_col =
+                assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+        nested_col.reserve(nested_col.size() + float_result.size());
+
+        size_t current_offset = nested_col.size();
+        nested_col.insert_many_raw_data(reinterpret_cast<const char*>(float_result.data()),
+                                        float_result.size());
+        offsets.push_back(current_offset + float_result.size());
+        auto& null_map = nested_nullable_col.get_null_map_column();
+        null_map.insert_many_vals(0, float_result.size());
+    }
+
+    static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
+        if (s.size() < prefix.size()) {
+            return false;
+        }
+        return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a, char b) {
+            return std::tolower(static_cast<unsigned char>(a)) ==
+                   std::tolower(static_cast<unsigned char>(b));
+        });
+    }
+
+    static Status _infer_media_type(const rapidjson::Value& file_input,
+                                    MultimodalType& media_type) {
+        std::string content_type;
+        RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type", content_type));
+
+        if (_starts_with_ignore_case(content_type, "image/")) {
+            media_type = MultimodalType::IMAGE;
+            return Status::OK();
+        } else if (_starts_with_ignore_case(content_type, "video/")) {
+            media_type = MultimodalType::VIDEO;
+            return Status::OK();
+        } else if (_starts_with_ignore_case(content_type, "audio/")) {
+            media_type = MultimodalType::AUDIO;
+            return Status::OK();
+        }
+
+        return Status::InvalidArgument("Unsupported content_type for EMBED: {}", content_type);
+    }
+
+    // Parse the FILE-like JSONB argument into a JSON object for downstream field reads.
+    static Status _parse_file_input(const ColumnWithTypeAndName& file_column, size_t row_num,
+                                    rapidjson::Document& file_input) {
+        std::string file_json =
+                JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
+                                                  file_column.column->get_data_at(row_num).size);
+        file_input.Parse(file_json.c_str());

Review Comment:
   `_parse_file_input()` is operating on user-provided `JSON`, but this uses 
`DORIS_CHECK` after parsing. The FE signature accepts any JSON value here, so a 
query like `SELECT EMBED('r', CAST('[]' AS JSON))` (or `CAST('null' AS JSON)`) 
reaches this path and aborts the BE instead of returning a normal user-facing 
error. This should validate and return 
`Status::InvalidArgument`/`Status::InternalError`, not crash.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to