github-actions[bot] commented on code in PR #62147:
URL: https://github.com/apache/doris/pull/62147#discussion_r3077190946
##########
be/src/exprs/function/ai/ai_adapter.h:
##########
@@ -969,6 +1207,51 @@ class GeminiAdapter : public AIAdapter {
return Status::OK();
}
+ Status build_multimodal_embedding_request(MultimodalType media_type,
+ const std::string& media_url,
+ std::string& request_body) const
override {
+ const char* mime_type = nullptr;
+ switch (media_type) {
+ case MultimodalType::IMAGE:
+ mime_type = "image/png";
+ break;
+ case MultimodalType::AUDIO:
+ mime_type = "audio/mpeg";
+ break;
+ case MultimodalType::VIDEO:
+ mime_type = "video/mp4";
+ break;
+ }
+
+ rapidjson::Document doc;
+ doc.SetObject();
+ auto& allocator = doc.GetAllocator();
+
+ std::string model_name = _config.model_name;
+ if (!model_name.starts_with("models/")) {
+ model_name = "models/" + model_name;
+ }
+ doc.AddMember("model", rapidjson::Value(model_name.c_str(),
allocator), allocator);
Review Comment:
This hardcodes Gemini to one MIME type per coarse media type (`image/png`,
`audio/mpeg`, `video/mp4`) and ignores the actual `content_type` supplied in
the FILE JSON. `_infer_media_type()` accepts any subtype with the same prefix,
so valid inputs such as JPEG, WebP, or WAV will be sent downstream with the
wrong MIME type and can be rejected or misinterpreted. Please pass the original
`content_type` through to this request builder (or reject unsupported subtypes
before signing/sending the URL).
##########
be/src/exprs/function/ai/embed.h:
##########
@@ -33,7 +42,221 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
return
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
}
+ Status execute_with_adapter(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ if (arguments.size() != 2) {
+ return Status::InvalidArgument("Function EMBED expects 2
arguments, but got {}",
+ arguments.size());
+ }
+
+ PrimitiveType input_type =
+
remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
+ if (input_type == PrimitiveType::TYPE_JSONB) {
+ return _execute_multimodal_embed(context, block, arguments,
result, input_rows_count,
+ config, adapter);
+ }
+ if (input_type == PrimitiveType::TYPE_STRING || input_type ==
PrimitiveType::TYPE_VARCHAR ||
+ input_type == PrimitiveType::TYPE_CHAR) {
+ return _execute_text_embed(context, block, arguments, result,
input_rows_count, config,
+ adapter);
+ }
+ return Status::InvalidArgument(
+ "Function EMBED expects the second argument to be STRING or
JSON, but got type {}",
+ block.get_by_position(arguments[1]).type->get_name());
+ }
+
static FunctionPtr create() { return std::make_shared<FunctionEmbed>(); }
+
+private:
+ Status _execute_text_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ std::string prompt;
+ RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_single_request(prompt, float_result,
config, adapter, context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ Status _execute_multimodal_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const
TAIResource& config,
+ std::shared_ptr<AIAdapter>& adapter)
const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ int64_t ttl_seconds = 3600;
+ QueryContext* query_ctx = context->state()->get_query_ctx();
+ if (query_ctx &&
query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
+ ttl_seconds =
query_ctx->query_options().file_presigned_url_ttl_seconds;
+ if (ttl_seconds <= 0) {
+ ttl_seconds = 3600;
+ }
+ }
+
+ const ColumnWithTypeAndName& file_column =
block.get_by_position(arguments[1]);
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ rapidjson::Document file_input;
+ RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
+
+ MultimodalType media_type;
+ RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
+
+ std::string media_url;
+ RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds,
media_url));
+
+ std::string request_body;
+
RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type,
media_url,
+
request_body));
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_embedding_request(request_body,
float_result, config, adapter,
+ context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ static void _insert_embedding_result(ColumnArray& col_array,
+ const std::vector<float>&
float_result) {
+ auto& offsets = col_array.get_offsets();
+ auto& nested_nullable_col =
assert_cast<ColumnNullable&>(col_array.get_data());
+ auto& nested_col =
+
assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+ nested_col.reserve(nested_col.size() + float_result.size());
+
+ size_t current_offset = nested_col.size();
+ nested_col.insert_many_raw_data(reinterpret_cast<const
char*>(float_result.data()),
+ float_result.size());
+ offsets.push_back(current_offset + float_result.size());
+ auto& null_map = nested_nullable_col.get_null_map_column();
+ null_map.insert_many_vals(0, float_result.size());
+ }
+
+ static bool _starts_with_ignore_case(std::string_view s, std::string_view
prefix) {
+ if (s.size() < prefix.size()) {
+ return false;
+ }
+ return std::equal(prefix.begin(), prefix.end(), s.begin(), [](char a,
char b) {
+ return std::tolower(static_cast<unsigned char>(a)) ==
+ std::tolower(static_cast<unsigned char>(b));
+ });
+ }
+
+ static Status _infer_media_type(const rapidjson::Value& file_input,
+ MultimodalType& media_type) {
+ std::string content_type;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type",
content_type));
+
+ if (_starts_with_ignore_case(content_type, "image/")) {
+ media_type = MultimodalType::IMAGE;
+ return Status::OK();
+ } else if (_starts_with_ignore_case(content_type, "video/")) {
+ media_type = MultimodalType::VIDEO;
+ return Status::OK();
+ } else if (_starts_with_ignore_case(content_type, "audio/")) {
+ media_type = MultimodalType::AUDIO;
+ return Status::OK();
+ }
+
+ return Status::InvalidArgument("Unsupported content_type for EMBED:
{}", content_type);
+ }
+
+ // Parse the FILE-like JSONB argument into a JSON object for downstream
field reads.
+ static Status _parse_file_input(const ColumnWithTypeAndName& file_column,
size_t row_num,
+ rapidjson::Document& file_input) {
+ std::string file_json =
+
JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
+
file_column.column->get_data_at(row_num).size);
+ file_input.Parse(file_json.c_str());
Review Comment:
`_parse_file_input()` operates on user-provided `JSON`, but the function uses
`DORIS_CHECK` after parsing. The FE signature accepts any JSON value here, so a
query like `SELECT EMBED('r', CAST('[]' AS JSON))` (or `CAST('null' AS JSON)`)
reaches this path and aborts the BE instead of returning a normal user-facing
error. This should validate the input and return
`Status::InvalidArgument`/`Status::InternalError`, not crash.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]