wyxxxcat commented on code in PR #61924:
URL: https://github.com/apache/doris/pull/61924#discussion_r3020940923
##########
be/src/exprs/function/ai/embed.h:
##########
@@ -33,7 +42,218 @@ class FunctionEmbed : public AIFunction<FunctionEmbed> {
return
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat32>()));
}
+ Status execute_with_adapter(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ if (arguments.size() != 2) {
+ return Status::InvalidArgument("Function EMBED expects 2
arguments, but got {}",
+ arguments.size());
+ }
+
+ PrimitiveType input_type =
+
remove_nullable(block.get_by_position(arguments[1]).type)->get_primitive_type();
+ if (input_type == PrimitiveType::TYPE_JSONB) {
+ return _execute_multimodal_embed(context, block, arguments,
result, input_rows_count,
+ config, adapter);
+ }
+ if (input_type == PrimitiveType::TYPE_STRING || input_type ==
PrimitiveType::TYPE_VARCHAR ||
+ input_type == PrimitiveType::TYPE_CHAR) {
+ return _execute_text_embed(context, block, arguments, result,
input_rows_count, config,
+ adapter);
+ }
+ return Status::InvalidArgument(
+ "Function EMBED expects the second argument to be STRING or
JSON, but got type {}",
+ block.get_by_position(arguments[1]).type->get_name());
+ }
+
// Factory: returns a freshly allocated FunctionEmbed as a FunctionPtr.
static FunctionPtr create() {
    FunctionPtr instance = std::make_shared<FunctionEmbed>();
    return instance;
}
+
+private:
+ Status _execute_text_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t result,
+ size_t input_rows_count, const TAIResource&
config,
+ std::shared_ptr<AIAdapter>& adapter) const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ std::string prompt;
+ RETURN_IF_ERROR(build_prompt(block, arguments, i, prompt));
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_single_request(prompt, float_result,
config, adapter, context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ Status _execute_multimodal_embed(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count, const
TAIResource& config,
+ std::shared_ptr<AIAdapter>& adapter)
const {
+ auto col_result = ColumnArray::create(
+ ColumnNullable::create(ColumnFloat32::create(),
ColumnUInt8::create()));
+
+ int64_t ttl_seconds = 3600;
+ QueryContext* query_ctx = context->state()->get_query_ctx();
+ if (query_ctx &&
query_ctx->query_options().__isset.file_presigned_url_ttl_seconds) {
+ ttl_seconds =
query_ctx->query_options().file_presigned_url_ttl_seconds;
+ if (ttl_seconds <= 0) {
+ ttl_seconds = 3600;
+ }
+ }
+
+ const ColumnWithTypeAndName& file_column =
block.get_by_position(arguments[1]);
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ rapidjson::Document file_input;
+ RETURN_IF_ERROR(_parse_file_input(file_column, i, file_input));
+
+ MultimodalType media_type;
+ RETURN_IF_ERROR(_infer_media_type(file_input, media_type));
+
+ std::string media_url;
+ RETURN_IF_ERROR(_resolve_media_url(file_input, ttl_seconds,
media_url));
+
+ std::string request_body;
+
RETURN_IF_ERROR(adapter->build_multimodal_embedding_request(media_type,
media_url,
+
request_body));
+
+ std::vector<float> float_result;
+ RETURN_IF_ERROR(execute_embedding_request(request_body,
float_result, config, adapter,
+ context));
+ _insert_embedding_result(*col_result, float_result);
+ }
+
+ block.replace_by_position(result, std::move(col_result));
+ return Status::OK();
+ }
+
+ static void _insert_embedding_result(ColumnArray& col_array,
+ const std::vector<float>&
float_result) {
+ auto& offsets = col_array.get_offsets();
+ auto& nested_nullable_col =
assert_cast<ColumnNullable&>(col_array.get_data());
+ auto& nested_col =
+
assert_cast<ColumnFloat32&>(*(nested_nullable_col.get_nested_column_ptr()));
+ nested_col.reserve(nested_col.size() + float_result.size());
+
+ size_t current_offset = nested_col.size();
+ nested_col.insert_many_raw_data(reinterpret_cast<const
char*>(float_result.data()),
+ float_result.size());
+ offsets.push_back(current_offset + float_result.size());
+ auto& null_map = nested_nullable_col.get_null_map_column();
+ null_map.insert_many_vals(0, float_result.size());
+ }
+
// Returns true when `s` begins with `prefix`. Characters are compared after
// std::tolower, applied to their unsigned char values. An empty prefix always
// matches.
static bool _starts_with_ignore_case(std::string_view s, std::string_view prefix) {
    if (prefix.size() > s.size()) {
        return false;
    }
    for (size_t i = 0; i < prefix.size(); ++i) {
        const auto lhs = static_cast<unsigned char>(s[i]);
        const auto rhs = static_cast<unsigned char>(prefix[i]);
        if (std::tolower(lhs) != std::tolower(rhs)) {
            return false;
        }
    }
    return true;
}
+
+ static Status _infer_media_type(const rapidjson::Value& file_input,
+ MultimodalType& media_type) {
+ std::string content_type;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "content_type",
content_type));
+
+ if (_starts_with_ignore_case(content_type, "image/")) {
+ media_type = MultimodalType::IMAGE;
+ return Status::OK();
+ }
+ if (_starts_with_ignore_case(content_type, "video/")) {
+ media_type = MultimodalType::VIDEO;
+ return Status::OK();
+ }
+ if (_starts_with_ignore_case(content_type, "audio/")) {
+ media_type = MultimodalType::AUDIO;
+ return Status::OK();
+ }
+ return Status::InvalidArgument("Unsupported content_type for EMBED:
{}", content_type);
+ }
+
+ // Parse the FILE-like JSONB argument into a JSON object for downstream
field reads.
+ static Status _parse_file_input(const ColumnWithTypeAndName& file_column,
size_t row_num,
+ rapidjson::Document& file_input) {
+ std::string file_json =
+
JsonbToJson::jsonb_to_json_string(file_column.column->get_data_at(row_num).data,
+
file_column.column->get_data_at(row_num).size);
+ file_input.Parse(file_json.c_str());
+ DORIS_CHECK(file_input.HasParseError() || file_input.IsObject());
+ return Status::OK();
+ }
+
+ // TODO(lzq): After support FILE type, We should use the interface
provided by FILE to get the fields
+ // replacing this function
+ static Status _get_required_string_field(const rapidjson::Value& obj,
const char* field_name,
+ std::string& value) {
+ auto iter = obj.FindMember(field_name);
+ if (iter == obj.MemberEnd() || !iter->value.IsString()) {
+ return Status::InvalidArgument(
+ "EMBED file json field '{}' is required and must be a
string", field_name);
+ }
+ value = iter->value.GetString();
+ if (value.empty()) {
+ return Status::InvalidArgument("EMBED file json field '{}' can not
be empty",
+ field_name);
+ }
+ return Status::OK();
+ }
+
+ static Status init_s3_client_conf_from_json(const rapidjson::Value&
file_input,
+ S3ClientConf& s3_client_conf) {
+ std::string endpoint;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "endpoint",
endpoint));
+ std::string region;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "region",
region));
+ std::string role_arn;
+ RETURN_IF_ERROR(_get_required_string_field(file_input, "role_arn",
role_arn));
+
+ auto external_id_iter = file_input.FindMember("external_id");
+ if (external_id_iter != file_input.MemberEnd()) {
+ DORIS_CHECK(external_id_iter->value.IsString());
+ s3_client_conf.external_id = external_id_iter->value.GetString();
+ }
+
+ s3_client_conf.endpoint = endpoint;
+ s3_client_conf.region = region;
+ s3_client_conf.role_arn = role_arn;
+
+ return Status::OK();
+ }
+
+ Status _resolve_media_url(const rapidjson::Value& file_input, int64_t
ttl_seconds,
Review Comment:
LGTM
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]