This is an automated email from the ASF dual-hosted git repository.
corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 12414c4eab [Feature][Transform-V2] Support multimodal embeddings
(#9673)
12414c4eab is described below
commit 12414c4eab869bf779193e2815db23f1d28a446e
Author: xiaochen <[email protected]>
AuthorDate: Mon Sep 8 18:34:55 2025 +0800
[Feature][Transform-V2] Support multimodal embeddings (#9673)
---
.../seatunnel-engine/hybrid-cluster-deployment.md | 2 +-
.../separated-cluster-deployment.md | 2 +-
docs/en/transform-v2/embedding.md | 303 ++++++++++++++++-
.../separated-cluster-deployment.md | 2 +-
docs/zh/transform-v2/embedding.md | 300 ++++++++++++++++-
.../seatunnel/api/table/type/CommonOptions.java | 12 +-
.../seatunnel/api/table/type/MetadataUtil.java | 33 ++
.../file/sink/writer/BinaryWriteStrategy.java | 5 +-
.../file/source/reader/BinaryReadStrategy.java | 9 +
.../file/reader/BinaryReadStrategyTest.java | 8 +-
.../binary/local_file_binary_to_assert.conf | 2 +-
.../seatunnel/e2e/transform/TestEmbeddingIT.java | 33 ++
.../src/test/resources/binary/cat.png | Bin 0 -> 1969877 bytes
.../test/resources/embedding_transform_binary.conf | 85 +++++
.../embedding_transform_binary_complete_file.conf | 84 +++++
.../resources/embedding_transform_multimodal.conf | 243 +++++++++++++
.../src/test/resources/mock-embedding.json | 39 +++
.../checkpoint/CheckpointErrorRestoreEndTest.java | 4 +-
.../common/MultipleFieldOutputTransform.java | 5 +
.../transform/nlpmodel/ModelProvider.java | 16 +-
.../nlpmodel/embedding/EmbeddingTransform.java | 189 +++++++++--
.../embedding/EmbeddingTransformConfig.java | 12 +-
.../transform/nlpmodel/embedding/FieldSpec.java | 119 +++++++
.../embedding/multimodal/ModalityType.java | 110 ++++++
.../embedding/multimodal/MultimodalFieldValue.java | 70 ++++
.../embedding/multimodal/MultimodalModel.java | 60 ++++
.../embedding/multimodal/PayloadFormat.java | 57 ++++
.../nlpmodel/embedding/remote/AbstractModel.java | 2 +-
.../embedding/remote/doubao/DoubaoModel.java | 159 ++++++++-
.../embedding/DoubaoMultimodalModelTest.java | 314 +++++++++++++++++
.../embedding/EmbeddingModelDimensionTest.java | 1 +
.../transform/embedding/FieldSpecTest.java | 114 +++++++
.../transform/embedding/MultimodalConfigTest.java | 375 +++++++++++++++++++++
33 files changed, 2708 insertions(+), 61 deletions(-)
diff --git a/docs/en/seatunnel-engine/hybrid-cluster-deployment.md
b/docs/en/seatunnel-engine/hybrid-cluster-deployment.md
index d63219b2f0..7e09112589 100644
--- a/docs/en/seatunnel-engine/hybrid-cluster-deployment.md
+++ b/docs/en/seatunnel-engine/hybrid-cluster-deployment.md
@@ -359,4 +359,4 @@ Now that the cluster is deployed, you can complete the
submission and management
### 8.2 Submit Jobs With The REST API
-The SeaTunnel Engine provides a REST API for submitting and managing jobs. For more information, please refer to [REST API V2](rest-api-v2.md)
+The SeaTunnel Engine provides a REST API for submitting and managing jobs. For more information, please refer to [REST API V2](rest-api-v2.md)
\ No newline at end of file
diff --git a/docs/en/seatunnel-engine/separated-cluster-deployment.md
b/docs/en/seatunnel-engine/separated-cluster-deployment.md
index d01043edef..4afbc033ed 100644
--- a/docs/en/seatunnel-engine/separated-cluster-deployment.md
+++ b/docs/en/seatunnel-engine/separated-cluster-deployment.md
@@ -472,4 +472,4 @@ Now that the cluster has been deployed, you can complete
the job submission and
### 8.2 Submit Jobs With The REST API
-The SeaTunnel Engine provides a REST API for submitting and managing jobs. For more information, please refer to [REST API V2](rest-api-v2.md)
+The SeaTunnel Engine provides a REST API for submitting and managing jobs. For more information, please refer to [REST API V2](rest-api-v2.md)
\ No newline at end of file
diff --git a/docs/en/transform-v2/embedding.md
b/docs/en/transform-v2/embedding.md
index 819285cd90..5c5bd4bf1c 100644
--- a/docs/en/transform-v2/embedding.md
+++ b/docs/en/transform-v2/embedding.md
@@ -4,8 +4,8 @@
## Description
-The `Embedding` transform plugin leverages embedding models to convert text data into vectorized representations. This
-transformation can be applied to various fields. The plugin supports multiple model providers and can be integrated with
+The `Embedding` transform plugin leverages embedding models to convert text and multimodal data into vectorized representations. This
+transformation can be applied to various fields including text, images, and videos. The plugin supports multiple model providers and can be integrated with
different API endpoints.
> **Important Note:** The current embedding precision only supports float32
> format.
@@ -59,15 +59,69 @@ capacity and the model provider's API limitations.
### vectorization_fields
A mapping between input fields and their respective output vector fields. This allows the plugin to understand which
-text fields to vectorize and how to store the resulting vectors.
+fields to vectorize and how to store the resulting vectors. The plugin supports multimodal data by allowing you to specify
+the modality type for each field.
+**Basic Text Vectorization:**
```hocon
vectorization_fields {
book_intro_vector = book_intro
- author_biography_vector = author_biography
+ author_biography_vector = author_biography
}
```
+**Multimodal Vectorization:**
+```hocon
+vectorization_fields {
+ # Basic text field
+ text_vector = text_field
+
+ # Explicit modality type configuration
+ product_image_vector = {
+ field = product_image_url
+ modality = jpeg
+ format = url
+ }
+
+ # Auto-detect modality type (based on file suffix)
+ thumbnail_vector = {
+ field = thumbnail_image # If value is "image.png", auto-detects as PNG modality
+ format = url
+ }
+
+ # Video field configuration
+ demo_video_vector = {
+ field = product_video_url
+ modality = mp4
+ format = url
+ }
+
+ # Binary data configuration
+ binary_image_vector = {
+ field = image_data
+ modality = jpeg
+ format = binary
+ }
+}
+```
+
+**Field Specification Formats:**
+
+**Supported Modality Types:**
+- **Images:** `jpeg` (jpg, jpeg), `png` (png, apng), `gif`, `webp`, `bmp` (bmp, dib), `tiff` (tiff, tif), `ico`, `icns`, `sgi`, `jpeg2000` (j2c, j2k, jp2, jpc, jpf, jpx)
+- **Videos:** `mp4`, `avi`, `mov`
+- **Text:** `text` (default)
+
+**Payload Formats:**
+- `text` - Text format (default)
+- `url` - URL format
+- `binary` - Binary data format
+
+**Automatic Modality Detection:**
+When `modality` is not explicitly specified and `format` is not `binary`, the system automatically detects the modality type based on the file suffix of the field value:
+
+> **Important:** When using multimodal fields (image or video), ensure your model provider supports multimodal embedding. Image and video fields must contain valid URLs or binary data. Currently, `DOUBAO` provider supports multimodal data processing.
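
The suffix-based detection described under "Automatic Modality Detection" can be illustrated with a minimal sketch. The class name `ModalitySuffixSketch`, its `detect` method, and the fallback to `text` for unknown suffixes are illustrative assumptions; the `ModalityType` class added by this commit is not shown in this part of the diff and may behave differently. Only a subset of the documented suffixes is registered here.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

// Hypothetical sketch, not the ModalityType class from this commit.
final class ModalitySuffixSketch {
    // Maps a file suffix to the modality name listed in the docs above.
    private static final Map<String, String> SUFFIX_TO_MODALITY = new HashMap<>();

    static {
        register("jpeg", "jpg", "jpeg");
        register("png", "png", "apng");
        register("gif", "gif");
        register("webp", "webp");
        register("bmp", "bmp", "dib");
        register("tiff", "tiff", "tif");
        register("mp4", "mp4");
        register("avi", "avi");
        register("mov", "mov");
    }

    private static void register(String modality, String... suffixes) {
        for (String suffix : suffixes) {
            SUFFIX_TO_MODALITY.put(suffix, modality);
        }
    }

    // Returns the modality for a field value based on the text after the last dot.
    // Assumption: values without a known suffix are treated as plain text.
    static String detect(String fieldValue) {
        if (fieldValue == null) {
            return "text";
        }
        int dot = fieldValue.lastIndexOf('.');
        if (dot < 0 || dot == fieldValue.length() - 1) {
            return "text";
        }
        String suffix = fieldValue.substring(dot + 1).toLowerCase(Locale.ROOT);
        return SUFFIX_TO_MODALITY.getOrDefault(suffix, "text");
    }

    public static void main(String[] args) {
        System.out.println(detect("https://example.com/thumbnails/iphone15pro_thumb.png"));
        System.out.println(detect("https://example.com/videos/iphone15pro_promo.mov"));
    }
}
```

Setting `modality` explicitly, as the `product_image_vector` example does, avoids relying on the suffix at all, which is useful when URLs carry no file extension.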
+
### model
The specific embedding model to use. This depends on the
`embedding_model_provider`. For example, if using OPENAI, you
@@ -137,7 +191,9 @@ The `custom_request_body` option supports placeholders:
Transform plugin common parameters, please refer to [Transform
Plugin](common-options.md) for details.
-## Example Configuration
+## Example Configurations
+
+### Basic Text Embedding
```hocon
env {
@@ -263,6 +319,243 @@ sink {
}
```
+### Multimodal Embedding (Volcengine Doubao)
+
+Multimodal Embedding supports input as accessible URL or Binary data formats to process multimodal data.
+
+#### URL
+
+```hocon
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ product_name = "string"
+ description = "string"
+ product_image_url = "string"
+ product_video_url = "string"
+ thumbnail_image = "string"
+ promotional_video = "string"
+ category = "string"
+ price = "decimal(10,2)"
+ created_at = "timestamp"
+ }
+ }
+ rows = [
+ {
+ fields = [
+ 1,
+ "iPhone 15 Pro",
+ "Latest iPhone with advanced camera system and A17 Pro chip",
+ "https://example.com/images/iphone15pro.jpg",
+ "https://example.com/videos/iphone15pro_demo.mp4",
+ "https://example.com/thumbnails/iphone15pro_thumb.png",
+ "https://example.com/videos/iphone15pro_promo.mov",
+ "Electronics",
+ 999.99,
+ "2024-01-15T10:30:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 2,
+ "MacBook Air M3",
+ "Ultra-thin laptop with M3 chip for incredible performance",
+ "https://example.com/images/macbook_air_m3.jpeg",
+ "https://example.com/videos/macbook_air_review.avi",
+ "https://example.com/thumbnails/macbook_thumb.webp",
+ "https://example.com/videos/macbook_commercial.mp4",
+ "Computers",
+ 1299.99,
+ "2024-02-20T14:15:00"
+ ],
+ kind = INSERT
+ }
+ ]
+ plugin_output = "fake"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "fake"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision"
+ api_key = "your-api-key"
+ api_path = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields {
+ # Text field - defaults to text modality
+ description_vector = description
+
+ product_image_vector = {
+ field = product_image_url
+ modality = jpeg
+ format = url
+ }
+
+ thumbnail_vector = {
+ field = thumbnail_image # If value is "thumb.png", auto-detects as PNG
+ format = url
+ }
+
+ demo_video_vector = {
+ field = product_video_url
+ modality = mp4
+ format = url
+ }
+
+ promo_video_vector = {
+ field = promotional_video # If value is "promo.mov", auto-detects as MOV
+ format = url
+ }
+
+ # Mixed content - product name
+ product_name_vector = product_name
+ }
+
+ plugin_output = "multimodal_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "multimodal_embedding_output"
+ rules = {
+ field_rules = [
+ {
+ field_name = id
+ field_type = int
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = description_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = product_image_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = thumbnail_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = demo_video_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
+```
+
+#### Binary
+
+```hocon
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ LocalFile {
+ path = "/seatunnel/read/binary/"
+ file_format_type = "binary"
+ binary_complete_file_mode = false
+ binary_chunk_size = 1024
+ plugin_output = "binary_source"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "binary_source"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision-250615"
+ api_key = "test-api-key"
+ api_path = "http://mockserver:1080/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields = {
+ image_embedding = {
+ field = "data"
+ modality = "jpeg"
+ format = "binary"
+ }
+ }
+
+ plugin_output = "binary_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "binary_embedding_output"
+ rules = {
+ row_rules = [
+ {
+ rule_type = MAX_ROW
+ rule_value = 1
+ }
+ ],
+ field_rules = [
+ {
+ field_name = image_embedding
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = relativePath
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
+```
+
+
### Customize the embedding model
```hocon
diff --git a/docs/zh/seatunnel-engine/separated-cluster-deployment.md
b/docs/zh/seatunnel-engine/separated-cluster-deployment.md
index 0bbfd12a47..8fa11e9174 100644
--- a/docs/zh/seatunnel-engine/separated-cluster-deployment.md
+++ b/docs/zh/seatunnel-engine/separated-cluster-deployment.md
@@ -486,4 +486,4 @@ hazelcast-client:
### 8.2 Submit Jobs With The REST API
-SeaTunnel Engine provides a REST API for submitting jobs. For details, see [REST API V2](rest-api-v2.md)
+SeaTunnel Engine provides a REST API for submitting jobs. For details, see [REST API V2](rest-api-v2.md)
\ No newline at end of file
diff --git a/docs/zh/transform-v2/embedding.md
b/docs/zh/transform-v2/embedding.md
index ccbdd66821..31ba604600 100644
--- a/docs/zh/transform-v2/embedding.md
+++ b/docs/zh/transform-v2/embedding.md
@@ -4,9 +4,9 @@
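## Description
-The `Embedding` transform plugin uses embedding models to convert text data into vectorized representations. This transformation can be applied to various fields. The plugin supports multiple model providers and can be integrated with different APIs.
+The `Embedding` transform plugin uses embedding models to convert text and multimodal data into vectorized representations. This transformation can be applied to various fields, including text, images, and videos. The plugin supports multiple model providers and can be integrated with different APIs.
-> **Important Note:** The current embedding precision only supports float32
+> **Important Note:** The current embedding precision only supports float32
## Configuration Options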
@@ -53,15 +53,68 @@
### vectorization_fields
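-A mapping between input fields and their corresponding output vector fields. This lets the plugin understand which text fields to vectorize and how to store the resulting vectors.
+A mapping between input fields and their corresponding output vector fields. This lets the plugin understand which fields to vectorize and how to store the resulting vectors. The plugin supports multimodal data by allowing you to specify the modality type for each field.
+**Basic Text Vectorization:**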
```hocon
vectorization_fields {
book_intro_vector = book_intro
- author_biography_vector = author_biography
+ author_biography_vector = author_biography
}
```
+**Multimodal Vectorization:**
+```hocon
+vectorization_fields {
+ # Basic text field
+ text_vector = text_field
+
+ # Explicit modality type configuration
+ product_image_vector = {
+ field = product_image_url
+ modality = jpeg
+ format = url
+ }
+
+ # Auto-detect modality type (based on file suffix)
+ thumbnail_vector = {
+ field = thumbnail_image # If value is "image.png", auto-detects as PNG modality
+ format = url
+ }
+
+ # Video field configuration
+ demo_video_vector = {
+ field = product_video_url
+ modality = mp4
+ format = url
+ }
+
+ # Binary data configuration
+ binary_image_vector = {
+ field = image_data
+ modality = jpeg
+ format = binary
+ }
+}
+```
+
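+**Field Specification Formats:**
+
+**Supported Modality Types:**
+- **Images:** `jpeg` (jpg, jpeg), `png` (png, apng), `gif`, `webp`, `bmp` (bmp, dib), `tiff` (tiff, tif), `ico`, `icns`, `sgi`, `jpeg2000` (j2c, j2k, jp2, jpc, jpf, jpx)
+- **Videos:** `mp4`, `avi`, `mov`
+- **Text:** `text` (default)
+
+**Payload Formats:**
+- `text` - Text format (default)
+- `url` - URL format
+- `binary` - Binary data format
+
+**Automatic Modality Detection:**
+When `modality` is not explicitly specified and `format` is not `binary`, the system automatically detects the modality type based on the file suffix of the field value:
+
+> **Important:** When using multimodal fields (image or video), ensure your model provider supports multimodal embedding. Image and video fields must contain valid URLs or binary data. Currently, the `DOUBAO` provider supports multimodal data processing.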
+
### model
The specific embedding model to use. This depends on the `embedding_model_provider`. For example, if using OPENAI, you can specify `text-embedding-3-small`.
@@ -127,6 +180,8 @@ vectorization_fields {
## Example Configuration
+### Basic Text Embedding
+
```hocon
env {
job.mode = "BATCH"
@@ -253,6 +308,243 @@ sink {
}
```
+### Multimodal Embedding (Volcengine Doubao)
+
+Multimodal Embedding accepts input as an accessible URL or as binary data to process multimodal data.
+
+#### Accessible URL
+
+```hocon
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ product_name = "string"
+ description = "string"
+ product_image_url = "string"
+ product_video_url = "string"
+ thumbnail_image = "string"
+ promotional_video = "string"
+ category = "string"
+ price = "decimal(10,2)"
+ created_at = "timestamp"
+ }
+ }
+ rows = [
+ {
+ fields = [
+ 1,
+ "iPhone 15 Pro",
+ "Latest iPhone with advanced camera system and A17 Pro chip",
+ "https://example.com/images/iphone15pro.jpg",
+ "https://example.com/videos/iphone15pro_demo.mp4",
+ "https://example.com/thumbnails/iphone15pro_thumb.png",
+ "https://example.com/videos/iphone15pro_promo.mov",
+ "Electronics",
+ 999.99,
+ "2024-01-15T10:30:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 2,
+ "MacBook Air M3",
+ "Ultra-thin laptop with M3 chip for incredible performance",
+ "https://example.com/images/macbook_air_m3.jpeg",
+ "https://example.com/videos/macbook_air_review.avi",
+ "https://example.com/thumbnails/macbook_thumb.webp",
+ "https://example.com/videos/macbook_commercial.mp4",
+ "Computers",
+ 1299.99,
+ "2024-02-20T14:15:00"
+ ],
+ kind = INSERT
+ }
+ ]
+ plugin_output = "fake"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "fake"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision"
+ api_key = "your-api-key"
+ api_path = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields {
+ # Text field - defaults to text modality
+ description_vector = description
+
+ # Explicit image modality
+ product_image_vector = {
+ field = product_image_url
+ modality = jpeg
+ format = url
+ }
+
+ thumbnail_vector = {
+ field = thumbnail_image
+ format = url
+ }
+
+ # Video field
+ demo_video_vector = {
+ field = product_video_url
+ modality = mp4
+ format = url
+ }
+
+ promo_video_vector = {
+ field = promotional_video # If value is "promo.mov", auto-detects as MOV
+ format = url
+ }
+
+ product_name_vector = product_name
+ }
+
+ plugin_output = "multimodal_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "multimodal_embedding_output"
+ rules = {
+ field_rules = [
+ {
+ field_name = id
+ field_type = int
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = description_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = product_image_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = thumbnail_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = demo_video_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
+```
+
+#### Binary format
+
+```hocon
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ LocalFile {
+ path = "/seatunnel/read/binary/"
+ file_format_type = "binary"
+ binary_complete_file_mode = false
+ binary_chunk_size = 1024
+ plugin_output = "binary_source"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "binary_source"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision-250615"
+ api_key = "test-api-key"
+ api_path = "http://mockserver:1080/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields = {
+ image_embedding = {
+ field = "data"
+ modality = "jpeg"
+ format = "binary"
+ }
+ }
+
+ plugin_output = "binary_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "binary_embedding_output"
+ rules = {
+ row_rules = [
+ {
+ rule_type = MAX_ROW
+ rule_value = 1
+ }
+ ],
+ field_rules = [
+ {
+ field_name = image_embedding
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = relativePath
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
+```
+
### Customize the embedding model
```hocon
diff --git
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/CommonOptions.java
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/CommonOptions.java
index d12810cad7..927253609f 100644
---
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/CommonOptions.java
+++
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/CommonOptions.java
@@ -56,7 +56,17 @@ public enum CommonOptions {
* The key of {@link SeaTunnelRow#getOptions()} to store the DELAY value
of the row value. And
* the data should be milliseconds.
*/
- DELAY("Delay", true);
+ DELAY("Delay", true),
+ /**
+ * The key of {@link SeaTunnelRow#getOptions()} to indicate whether the row represents a
+ * complete file.
+ */
+ IS_COMPLETE("is_complete", true),
+ /**
+ * The key of {@link SeaTunnelRow#getOptions()} to indicate whether the row contains binary
+ * format data.
+ */
+ IS_BINARY_FORMAT("is_binary_format", true);
private final String name;
private final boolean supportMetadataTrans;
diff --git
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/MetadataUtil.java
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/MetadataUtil.java
index ed5fb4d615..4631126f72 100644
---
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/MetadataUtil.java
+++
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/table/type/MetadataUtil.java
@@ -25,6 +25,8 @@ import java.util.stream.Stream;
import static org.apache.seatunnel.api.table.type.CommonOptions.DELAY;
import static org.apache.seatunnel.api.table.type.CommonOptions.EVENT_TIME;
+import static
org.apache.seatunnel.api.table.type.CommonOptions.IS_BINARY_FORMAT;
+import static org.apache.seatunnel.api.table.type.CommonOptions.IS_COMPLETE;
import static org.apache.seatunnel.api.table.type.CommonOptions.PARTITION;
public class MetadataUtil {
@@ -51,6 +53,22 @@ public class MetadataUtil {
row.getOptions().put(EVENT_TIME.getName(), delay);
}
+ public static void setBinaryRowComplete(SeaTunnelRow row) {
+ row.getOptions().put(IS_COMPLETE.getName(), true);
+ }
+
+ public static void setBinaryFormat(SeaTunnelRow row) {
+ row.getOptions().put(IS_BINARY_FORMAT.getName(), true);
+ }
+
+ public static boolean isComplete(Object row) {
+ return checkOption(row, IS_COMPLETE.getName(), false);
+ }
+
+ public static boolean isBinaryFormat(Object row) {
+ return checkOption(row, IS_BINARY_FORMAT.getName(), false);
+ }
+
public static String getDatabase(SeaTunnelRowAccessor row) {
if (row.getTableId() == null) {
return null;
@@ -76,4 +94,19 @@ public class MetadataUtil {
public static boolean isMetadataField(String fieldName) {
return METADATA_FIELDS.contains(fieldName);
}
+
+ public static <T> boolean checkOption(T row, String optionKey, boolean
defaultValue) {
+ if (row instanceof SeaTunnelRow) {
+ return ((SeaTunnelRow) row)
+ .getOptions()
+ .getOrDefault(optionKey, defaultValue)
+ .equals(true);
+ } else if (row instanceof SeaTunnelRowAccessor) {
+ return ((SeaTunnelRowAccessor) row)
+ .getOptions()
+ .getOrDefault(optionKey, defaultValue)
+ .equals(true);
+ }
+ throw new IllegalArgumentException("Unsupported row type: " +
row.getClass().getName());
+ }
}
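
The flag helpers added to `MetadataUtil` above boil down to a boolean lookup in the row's options map. A minimal sketch of that lookup, assuming a plain `Map<String, Object>` as a stand-in for `SeaTunnelRow#getOptions()`, follows. The class name `RowOptionSketch` is hypothetical, and the use of `Boolean.TRUE.equals(...)` is a variation chosen here, not the committed code, so that a key mapped to `null` reads as `false` instead of throwing.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the option check in MetadataUtil; not SeaTunnel code.
final class RowOptionSketch {
    // Returns true only when the option is present (or defaulted) and equals Boolean.TRUE.
    static boolean checkOption(Map<String, Object> options, String key, boolean defaultValue) {
        // getOrDefault falls back to defaultValue only when the key is absent.
        return Boolean.TRUE.equals(options.getOrDefault(key, defaultValue));
    }

    public static void main(String[] args) {
        Map<String, Object> options = new HashMap<>();
        // Key names mirror the CommonOptions entries added in this commit.
        options.put("is_binary_format", true);
        System.out.println(checkOption(options, "is_binary_format", false));
        System.out.println(checkOption(options, "is_complete", false));
    }
}
```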
diff --git
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/sink/writer/BinaryWriteStrategy.java
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/sink/writer/BinaryWriteStrategy.java
index db3f0c1fc2..1fe307e33d 100644
---
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/sink/writer/BinaryWriteStrategy.java
+++
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/sink/writer/BinaryWriteStrategy.java
@@ -62,9 +62,12 @@ public class BinaryWriteStrategy extends
AbstractWriteStrategy<FSDataOutputStrea
@Override
public void write(SeaTunnelRow seaTunnelRow) throws FileConnectorException
{
+ long partIndex = (long) seaTunnelRow.getField(2);
+ if (partIndex == -1) {
+ return;
+ }
byte[] data = (byte[]) seaTunnelRow.getField(0);
String relativePath = (String) seaTunnelRow.getField(1);
- long partIndex = (long) seaTunnelRow.getField(2);
String filePath = getOrCreateFilePathBeingWritten(relativePath);
FSDataOutputStream fsDataOutputStream =
getOrCreateOutputStream(filePath);
if (partIndex - 1 != partIndexMap.get(filePath)) {
diff --git
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/source/reader/BinaryReadStrategy.java
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/source/reader/BinaryReadStrategy.java
index 897c37a344..af6431646a 100644
---
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/source/reader/BinaryReadStrategy.java
+++
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/main/java/org/apache/seatunnel/connectors/seatunnel/file/source/reader/BinaryReadStrategy.java
@@ -19,6 +19,7 @@ package
org.apache.seatunnel.connectors.seatunnel.file.source.reader;
import org.apache.seatunnel.api.source.Collector;
import org.apache.seatunnel.api.table.type.BasicType;
+import org.apache.seatunnel.api.table.type.MetadataUtil;
import org.apache.seatunnel.api.table.type.PrimitiveByteArrayType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
@@ -99,6 +100,12 @@ public class BinaryReadStrategy extends
AbstractReadStrategy {
// Read file in configurable chunks
readFileInChunks(inputStream, relativePath, tableId, output);
}
+ // Send an empty chunk as end-of-file marker
+ byte[] endMarker = new byte[0];
+ SeaTunnelRow endRow = new SeaTunnelRow(new Object[] {endMarker,
relativePath, -1L});
+ endRow.setTableId(tableId);
+ MetadataUtil.setBinaryRowComplete(endRow);
+ output.collect(endRow);
}
}
@@ -112,6 +119,7 @@ public class BinaryReadStrategy extends
AbstractReadStrategy {
byte[] fileContent = IOUtils.toByteArray(inputStream);
SeaTunnelRow row = new SeaTunnelRow(new Object[] {fileContent,
relativePath, 0L});
row.setTableId(tableId);
+ MetadataUtil.setBinaryFormat(row);
output.collect(row);
}
@@ -132,6 +140,7 @@ public class BinaryReadStrategy extends
AbstractReadStrategy {
SeaTunnelRow row = new SeaTunnelRow(new Object[] {buffer,
relativePath, partIndex});
buffer = new byte[binaryChunkSize];
row.setTableId(tableId);
+ MetadataUtil.setBinaryFormat(row);
output.collect(row);
partIndex++;
}
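
The reader change above appends an empty row with part index `-1` after a file's data rows, and `BinaryWriteStrategy` skips that row. A minimal sketch of the protocol on an in-memory byte array follows. The `ChunkedReadSketch` class, its `Chunk` type, and the part index starting at `0` in chunked mode are assumptions made for illustration; the real strategy reads from an input stream and emits `SeaTunnelRow` objects.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of chunking with an end-of-file marker; not SeaTunnel code.
final class ChunkedReadSketch {
    static final long END_OF_FILE_INDEX = -1L;

    static final class Chunk {
        final byte[] data;
        final long partIndex;

        Chunk(byte[] data, long partIndex) {
            this.data = data;
            this.partIndex = partIndex;
        }
    }

    // Splits content into chunks of at most chunkSize bytes, then appends an empty marker chunk.
    static List<Chunk> split(byte[] content, int chunkSize) {
        List<Chunk> chunks = new ArrayList<>();
        long partIndex = 0; // assumption: data chunks are numbered from 0
        for (int offset = 0; offset < content.length; offset += chunkSize) {
            int end = Math.min(content.length, offset + chunkSize);
            chunks.add(new Chunk(Arrays.copyOfRange(content, offset, end), partIndex++));
        }
        chunks.add(new Chunk(new byte[0], END_OF_FILE_INDEX));
        return chunks;
    }

    // Consumer side: sums payload bytes and ignores the marker, as the writer change does.
    static int payloadSize(List<Chunk> chunks) {
        int size = 0;
        for (Chunk chunk : chunks) {
            if (chunk.partIndex == END_OF_FILE_INDEX) {
                continue;
            }
            size += chunk.data.length;
        }
        return size;
    }

    public static void main(String[] args) {
        List<Chunk> chunks = split(new byte[1500], 512);
        System.out.println(chunks.size() + " rows, " + payloadSize(chunks) + " payload bytes");
    }
}
```

The extra marker row is why the row counts asserted in `BinaryReadStrategyTest` and `local_file_binary_to_assert.conf` each grow by one in this commit.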
diff --git
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/test/java/org/apache/seatunnel/connectors/seatunnel/file/reader/BinaryReadStrategyTest.java
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/test/java/org/apache/seatunnel/connectors/seatunnel/file/reader/BinaryReadStrategyTest.java
index 90fc3d03bb..fd969c2095 100644
---
a/seatunnel-connectors-v2/connector-file/connector-file-base/src/test/java/org/apache/seatunnel/connectors/seatunnel/file/reader/BinaryReadStrategyTest.java
+++
b/seatunnel-connectors-v2/connector-file/connector-file-base/src/test/java/org/apache/seatunnel/connectors/seatunnel/file/reader/BinaryReadStrategyTest.java
@@ -70,7 +70,9 @@ public class BinaryReadStrategyTest {
List<SeaTunnelRow> rows = collector.getRows();
Assertions.assertEquals(
- 2, rows.size(), "Should have 2 chunks for 2048 bytes with
default 1024 chunk size");
+ 2 + 1,
+ rows.size(),
+ "Should have 3 chunks for 2048 bytes with default 1024 chunk
size");
// Verify first chunk
SeaTunnelRow firstRow = rows.get(0);
@@ -102,7 +104,7 @@ public class BinaryReadStrategyTest {
List<SeaTunnelRow> rows = collector.getRows();
Assertions.assertEquals(
- 3, rows.size(), "Should have 3 chunks for 1500 bytes with 512
chunk size");
+ 3 + 1, rows.size(), "Should have 4 chunks for 1500 bytes with
512 chunk size");
// Verify chunk sizes: 512, 512, 476
Assertions.assertEquals(512, ((byte[])
rows.get(0).getField(0)).length);
@@ -128,7 +130,7 @@ public class BinaryReadStrategyTest {
binaryReadStrategy.read(testFile.getAbsolutePath(), "test_table",
collector);
List<SeaTunnelRow> rows = collector.getRows();
- Assertions.assertEquals(1, rows.size(), "Should have 1 row in complete
file mode");
+ Assertions.assertEquals(1 + 1, rows.size(), "Should have 2 row in
complete file mode");
SeaTunnelRow row = rows.get(0);
byte[] fileData = (byte[]) row.getField(0);
diff --git
a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-file-local-e2e/src/test/resources/binary/local_file_binary_to_assert.conf
b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-file-local-e2e/src/test/resources/binary/local_file_binary_to_assert.conf
index c66d53f280..e8e9a8f8aa 100644
---
a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-file-local-e2e/src/test/resources/binary/local_file_binary_to_assert.conf
+++
b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-file-local-e2e/src/test/resources/binary/local_file_binary_to_assert.conf
@@ -32,7 +32,7 @@ sink {
row_rules = [
{
rule_type = MAX_ROW
- rule_value = 1924
+ rule_value = 1925
}
]
}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestEmbeddingIT.java
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestEmbeddingIT.java
index 034479a302..7140879da1 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestEmbeddingIT.java
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestEmbeddingIT.java
@@ -18,9 +18,12 @@
package org.apache.seatunnel.e2e.transform;
import org.apache.seatunnel.e2e.common.TestResource;
+import org.apache.seatunnel.e2e.common.container.ContainerExtendedFactory;
import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;
+import org.apache.seatunnel.e2e.common.junit.TestContainerExtension;
+import org.apache.seatunnel.e2e.common.util.ContainerUtil;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
@@ -79,6 +82,13 @@ public class TestEmbeddingIT extends TestSuiteBase
implements TestResource {
Startables.deepStart(Stream.of(mockserverContainer)).join();
}
+ @TestContainerExtension
+ private final ContainerExtendedFactory extendedFactory =
+ container -> {
+ ContainerUtil.copyFileIntoContainers(
+ "/binary/cat.png", "/seatunnel/read/binary/cat.png",
container);
+ };
+
@AfterAll
@Override
public void tearDown() throws Exception {
@@ -93,6 +103,14 @@ public class TestEmbeddingIT extends TestSuiteBase
implements TestResource {
Assertions.assertEquals(0, execResult.getExitCode());
}
+ @TestTemplate
+ public void testMultimodalEmbedding(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
+ container.executeJob("/embedding_transform_multimodal.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
+
@TestTemplate
public void testEmbeddingMultiTable(TestContainer container)
throws IOException, InterruptedException {
@@ -107,4 +125,19 @@ public class TestEmbeddingIT extends TestSuiteBase
implements TestResource {
Container.ExecResult execResult =
container.executeJob("/embedding_transform_custom.conf");
Assertions.assertEquals(0, execResult.getExitCode());
}
+
+ @TestTemplate
+ public void testBinaryEmbeddingWithCompleteMode(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
+
container.executeJob("/embedding_transform_binary_complete_file.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
+
+ @TestTemplate
+ public void testBinaryEmbedding(TestContainer container)
+ throws IOException, InterruptedException {
+ Container.ExecResult execResult =
container.executeJob("/embedding_transform_binary.conf");
+ Assertions.assertEquals(0, execResult.getExitCode());
+ }
}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/binary/cat.png
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/binary/cat.png
new file mode 100644
index 0000000000..fb39446a11
Binary files /dev/null and
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/binary/cat.png
differ
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary.conf
new file mode 100644
index 0000000000..486d69ed6b
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary.conf
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ LocalFile {
+ path = "/seatunnel/read/binary/"
+ file_format_type = "binary"
+ binary_complete_file_mode = false
+ binary_chunk_size = 1024
+ plugin_output = "binary_source"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "binary_source"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision-250615"
+ api_key = "test-api-key"
+ api_path = "http://mockserver:1080/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields = {
+ image_embedding = {
+ field = "data"
+ modality = "jpeg"
+ format = "binary"
+ }
+ }
+
+ plugin_output = "binary_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "binary_embedding_output"
+ rules = {
+ row_rules = [
+ {
+ rule_type = MAX_ROW
+ rule_value = 1
+ }
+ ],
+ field_rules = [
+ {
+ field_name = image_embedding
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = relativePath
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary_complete_file.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary_complete_file.conf
new file mode 100644
index 0000000000..9988e70dbf
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_binary_complete_file.conf
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ LocalFile {
+ path = "/seatunnel/read/binary/"
+ file_format_type = "binary"
+ binary_complete_file_mode = true
+ plugin_output = "binary_source"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "binary_source"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision-250615"
+ api_key = "test-api-key"
+ api_path = "http://mockserver:1080/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields = {
+ image_embedding = {
+ field = "data"
+ modality = "jpeg"
+ format = "binary"
+ }
+ }
+
+ plugin_output = "binary_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "binary_embedding_output"
+ rules = {
+ row_rules = [
+ {
+ rule_type = MAX_ROW
+ rule_value = 1
+ }
+ ],
+ field_rules = [
+ {
+ field_name = image_embedding
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = relativePath
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf
new file mode 100644
index 0000000000..efc7273145
--- /dev/null
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf
@@ -0,0 +1,243 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+env {
+ job.mode = "BATCH"
+}
+
+source {
+ FakeSource {
+ row.num = 5
+ schema = {
+ fields {
+ id = "int"
+ product_name = "string"
+ description = "string"
+ product_image_url = "string"
+ product_video_url = "string"
+ thumbnail_image = "string"
+ promotional_video = "string"
+ category = "string"
+ price = "double"
+ created_at = "timestamp"
+ }
+ }
+ rows = [
+ {
+ fields = [
+ 1,
+ "iPhone 15 Pro",
+ "Latest iPhone with advanced camera system and A17 Pro chip",
+ "https://example.com/images/iphone15pro.jpg",
+ "https://example.com/videos/iphone15pro_demo.mp4",
+ "https://example.com/thumbnails/iphone15pro_thumb.png",
+ "https://example.com/videos/iphone15pro_promo.mov",
+ "Electronics",
+ 999.99,
+ "2024-01-15T10:30:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 2,
+ "MacBook Air M3",
+ "Ultra-thin laptop with M3 chip for incredible performance",
+ "https://example.com/images/macbook_air_m3.jpeg",
+ "https://example.com/videos/macbook_air_review.avi",
+ "https://example.com/thumbnails/macbook_thumb.webp",
+ "https://example.com/videos/macbook_commercial.mp4",
+ "Computers",
+ 1299.99,
+ "2024-02-20T14:15:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 3,
+ "AirPods Pro 2",
+ "Wireless earbuds with active noise cancellation",
+ "https://example.com/images/airpods_pro2.gif",
+ "https://example.com/videos/airpods_demo.mp4",
+ "https://example.com/thumbnails/airpods_thumb.bmp",
+ "https://example.com/videos/airpods_ad.mov",
+ "Audio",
+ 249.99,
+ "2024-03-10T09:45:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 4,
+ "Apple Watch Series 9",
+ "Advanced health monitoring and fitness tracking smartwatch",
+ "https://example.com/images/apple_watch_s9.tiff",
+ "https://example.com/videos/watch_features.avi",
+ "https://example.com/thumbnails/watch_thumb.png",
+ "https://example.com/videos/watch_lifestyle.mp4",
+ "Wearables",
+ 399.99,
+ "2024-04-05T16:20:00"
+ ],
+ kind = INSERT
+ },
+ {
+ fields = [
+ 5,
+ "iPad Pro 12.9",
+ "Professional tablet with M2 chip and Liquid Retina XDR display",
+ "https://example.com/images/ipad_pro_129.jpg",
+ "https://example.com/videos/ipad_creative_demo.mov",
+ "https://example.com/thumbnails/ipad_thumb.jpeg",
+ "https://example.com/videos/ipad_productivity.avi",
+ "Tablets",
+ 1099.99,
+ "2024-05-12T11:30:00"
+ ],
+ kind = INSERT
+ }
+ ]
+ plugin_output = "fake"
+ }
+}
+
+transform {
+ Embedding {
+ plugin_input = "fake"
+ model_provider = DOUBAO
+ model = "doubao-embedding-vision-250615"
+ api_key = "xxxxxxxx"
+ api_path = "http://mockserver:1080/api/v3/embeddings/multimodal"
+ single_vectorized_input_number = 1
+
+ vectorization_fields {
+ description_vector = description
+
+ product_image_vector = {
+ field = product_image_url
+ modality = jpeg
+ format = url
+ }
+
+ thumbnail_vector = {
+ field = thumbnail_image
+ modality = png
+ format = url
+ }
+
+ demo_video_vector = {
+ field = product_video_url
+ modality = mp4
+ format = url
+ }
+
+ promo_video_vector = {
+ field = promotional_video
+ modality = mov
+ format = url
+ }
+
+ product_name_vector = product_name
+ }
+
+ plugin_output = "multimodal_embedding_output"
+ }
+}
+
+sink {
+ Assert {
+ plugin_input = "multimodal_embedding_output"
+ rules = {
+ field_rules = [
+ {
+ field_name = description_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = product_image_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = thumbnail_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = demo_video_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = promo_video_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = product_name_vector
+ field_type = float_vector
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = category
+ field_type = string
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ },
+ {
+ field_name = price
+ field_type = double
+ field_value = [
+ {
+ rule_type = NOT_NULL
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mock-embedding.json
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mock-embedding.json
index af617f8080..8405939ec5 100644
---
a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mock-embedding.json
+++
b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mock-embedding.json
@@ -85,6 +85,45 @@
"Content-Type": "application/json"
}
}
+ },
+ {
+ "httpRequest": {
+ "method": "POST",
+ "path": "/api/v3/embeddings/multimodal",
+ "headers": {
+ "Authorization": [
+ "Bearer .*"
+ ],
+ "Content-Type": [
+ "application/json"
+ ]
+ }
+ },
+ "httpResponse": {
+ "body": {
+ "created": 1743575029,
+ "data": {
+ "embedding": [
+ -0.123046875, -0.35546875, -0.318359375, -0.255859375
+ ],
+ "object": "embedding"
+ },
+ "id": "021743575029461acbe49a31755bec77b2f09448eb15fa9a88e47",
+ "model": "doubao-embedding-vision-250615",
+ "object": "list",
+ "usage": {
+ "prompt_tokens": 13987,
+ "prompt_tokens_details": {
+ "image_tokens": 13800,
+ "text_tokens": 187
+ },
+ "total_tokens": 13987
+ }
+ },
+ "headers": {
+ "Content-Type": "application/json"
+ }
+ }
}
]
diff --git
a/seatunnel-engine/seatunnel-engine-server/src/test/java/org/apache/seatunnel/engine/server/checkpoint/CheckpointErrorRestoreEndTest.java
b/seatunnel-engine/seatunnel-engine-server/src/test/java/org/apache/seatunnel/engine/server/checkpoint/CheckpointErrorRestoreEndTest.java
index 643c486962..1aa22b4b9d 100644
---
a/seatunnel-engine/seatunnel-engine-server/src/test/java/org/apache/seatunnel/engine/server/checkpoint/CheckpointErrorRestoreEndTest.java
+++
b/seatunnel-engine/seatunnel-engine-server/src/test/java/org/apache/seatunnel/engine/server/checkpoint/CheckpointErrorRestoreEndTest.java
@@ -43,7 +43,7 @@ public class CheckpointErrorRestoreEndTest
JobMaster jobMaster =
server.getCoordinatorService().getJobMaster(jobId);
Assertions.assertEquals(1,
jobMaster.getPhysicalPlan().getPipelineList().size());
- await().atMost(120, TimeUnit.SECONDS)
+ await().atMost(240, TimeUnit.SECONDS)
.untilAsserted(
() ->
Assertions.assertEquals(
@@ -53,7 +53,7 @@ public class CheckpointErrorRestoreEndTest
.getPipelineList()
.get(0)
.getPipelineRestoreNum()));
- await().atMost(120, TimeUnit.SECONDS)
+ await().atMost(240, TimeUnit.SECONDS)
.untilAsserted(
() ->
Assertions.assertEquals(
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/MultipleFieldOutputTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/MultipleFieldOutputTransform.java
index f385b3cfd6..b5974e7e40 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/MultipleFieldOutputTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/common/MultipleFieldOutputTransform.java
@@ -22,6 +22,7 @@ import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.ConstraintKey;
import org.apache.seatunnel.api.table.catalog.TableIdentifier;
import org.apache.seatunnel.api.table.catalog.TableSchema;
+import org.apache.seatunnel.api.table.type.MetadataUtil;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowAccessor;
@@ -53,7 +54,11 @@ public abstract class MultipleFieldOutputTransform extends
AbstractCatalogSuppor
@Override
protected SeaTunnelRow transformRow(SeaTunnelRow inputRow) {
+
Object[] fieldValues = getOutputFieldValues(new
SeaTunnelRowAccessor(inputRow));
+ if (MetadataUtil.isBinaryFormat(inputRow) &&
!MetadataUtil.isComplete(inputRow)) {
+ return null;
+ }
SeaTunnelRow outputRow = rowContainerGenerator.apply(inputRow);
for (int i = 0; i < outputFieldNames.length; i++) {
outputRow.setField(fieldsIndex[i], fieldValues == null ? null :
fieldValues[i]);
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
index 8e26f85ece..23f0968f7b 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/ModelProvider.java
@@ -24,7 +24,8 @@ public enum ModelProvider {
OPENAI("https://api.openai.com/v1/chat/completions",
"https://api.openai.com/v1/embeddings"),
DOUBAO(
"https://ark.cn-beijing.volces.com/api/v3/chat/completions",
- "https://ark.cn-beijing.volces.com/api/v3/embeddings"),
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"),
QIANFAN("",
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
DEEPSEEK("https://api.deepseek.com/chat/completions", ""),
@@ -37,10 +38,19 @@ public enum ModelProvider {
private final String LLMProviderPath;
private final String EmbeddingProviderPath;
+ private final String MultimodalEmbeddingProviderPath;
ModelProvider(String llmProviderPath, String embeddingProviderPath) {
+ this(llmProviderPath, embeddingProviderPath, "");
+ }
+
+ ModelProvider(
+ String llmProviderPath,
+ String embeddingProviderPath,
+ String multimodalEmbeddingProviderPath) {
LLMProviderPath = llmProviderPath;
EmbeddingProviderPath = embeddingProviderPath;
+ MultimodalEmbeddingProviderPath = multimodalEmbeddingProviderPath;
}
public String usedLLMPath(String path) {
@@ -50,9 +60,9 @@ public enum ModelProvider {
return path;
}
- public String usedEmbeddingPath(String path) {
+ public String usedEmbeddingPath(String path, boolean isMultimodalFields) {
if (StringUtils.isBlank(path)) {
- return EmbeddingProviderPath;
+ return isMultimodalFields ? MultimodalEmbeddingProviderPath :
EmbeddingProviderPath;
}
return path;
}
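Note on the hunk above: the fallback in usedEmbeddingPath can be sketched in isolation as follows. This is a minimal illustration, not the project's class. EmbeddingPathSketch and its two constants are hypothetical stand-ins for the DOUBAO enum entry, and the blank check is written by hand instead of calling StringUtils.isBlank so the snippet has no dependencies.

```java
public class EmbeddingPathSketch {
    // Hypothetical constants mirroring the two DOUBAO embedding endpoints in the diff.
    static final String TEXT_PATH = "https://ark.cn-beijing.volces.com/api/v3/embeddings";
    static final String MULTIMODAL_PATH =
            "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal";

    // An explicitly configured path always wins; the provider default is only
    // used when the path is blank, and it depends on whether any field is multimodal.
    static String usedEmbeddingPath(String path, boolean isMultimodalFields) {
        if (path == null || path.trim().isEmpty()) {
            return isMultimodalFields ? MULTIMODAL_PATH : TEXT_PATH;
        }
        return path;
    }

    public static void main(String[] args) {
        System.out.println(usedEmbeddingPath("", false));
        System.out.println(usedEmbeddingPath(null, true));
        System.out.println(usedEmbeddingPath("http://mockserver:1080/api/v3/embeddings/multimodal", false));
    }
}
```

Providers declared with the two-argument constructor get an empty multimodal path, so a blank api_path combined with multimodal fields would resolve to an empty string for them.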
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
index 71310a1d54..8857e6264c 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java
@@ -23,6 +23,7 @@ import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
+import org.apache.seatunnel.api.table.type.MetadataUtil;
import org.apache.seatunnel.api.table.type.SeaTunnelRowAccessor;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.VectorType;
@@ -30,6 +31,8 @@ import
org.apache.seatunnel.transform.common.MultipleFieldOutputTransform;
import org.apache.seatunnel.transform.exception.TransformCommonError;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalFieldValue;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalModel;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.Model;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.amazon.BedrockModel;
import
org.apache.seatunnel.transform.nlpmodel.embedding.remote.custom.CustomModel;
@@ -41,29 +44,39 @@ import
org.apache.seatunnel.transform.nlpmodel.llm.LLMTransformConfig;
import lombok.NonNull;
import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+@Slf4j
public class EmbeddingTransform extends MultipleFieldOutputTransform {
private final ReadonlyConfig config;
- private List<String> fieldNames;
private List<Integer> fieldOriginalIndexes;
private transient Model model;
private Integer dimension;
+ private boolean isMultimodalFields = false;
+ private Map<Integer, FieldSpec> fieldSpecMap;
+ private List<String> fieldNames;
+
+ private final Map<String, TreeMap<Long, byte[]>> binaryFileCache = new
ConcurrentHashMap<>();
+ private final Map<String, Long> partIndexMap = new ConcurrentHashMap<>();
public EmbeddingTransform(
@NonNull ReadonlyConfig config, @NonNull CatalogTable
inputCatalogTable) {
super(inputCatalogTable);
this.config = config;
- initOutputFields(
- inputCatalogTable.getTableSchema().toPhysicalRowDataType(),
- config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS));
+
initOutputFields(inputCatalogTable.getTableSchema().toPhysicalRowDataType(),
config);
}
private void tryOpen() {
@@ -74,8 +87,10 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
@Override
public void open() {
- // Initialize model
ModelProvider provider =
config.get(ModelTransformConfig.MODEL_PROVIDER);
+ String apiPath =
+ provider.usedEmbeddingPath(
+ config.get(ModelTransformConfig.API_PATH),
isMultimodalFields);
try {
switch (provider) {
case CUSTOM:
@@ -91,8 +106,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
model =
new CustomModel(
config.get(ModelTransformConfig.MODEL),
- provider.usedEmbeddingPath(
-
config.get(ModelTransformConfig.API_PATH)),
+ apiPath,
customConfig.get(
LLMTransformConfig.CustomRequestConfig
.CUSTOM_REQUEST_HEADERS),
@@ -111,8 +125,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
new OpenAIModel(
config.get(ModelTransformConfig.API_KEY),
config.get(ModelTransformConfig.MODEL),
- provider.usedEmbeddingPath(
-
config.get(ModelTransformConfig.API_PATH)),
+ apiPath,
config.get(
EmbeddingTransformConfig
.SINGLE_VECTORIZED_INPUT_NUMBER));
@@ -122,11 +135,11 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
new DoubaoModel(
config.get(ModelTransformConfig.API_KEY),
config.get(ModelTransformConfig.MODEL),
- provider.usedEmbeddingPath(
-
config.get(ModelTransformConfig.API_PATH)),
+ apiPath,
config.get(
EmbeddingTransformConfig
-
.SINGLE_VECTORIZED_INPUT_NUMBER));
+
.SINGLE_VECTORIZED_INPUT_NUMBER),
+ isMultimodalFields);
break;
case QIANFAN:
model =
@@ -134,8 +147,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
config.get(ModelTransformConfig.API_KEY),
config.get(ModelTransformConfig.SECRET_KEY),
config.get(ModelTransformConfig.MODEL),
- provider.usedEmbeddingPath(
-
config.get(ModelTransformConfig.API_PATH)),
+ apiPath,
config.get(ModelTransformConfig.OAUTH_PATH),
config.get(
EmbeddingTransformConfig
@@ -147,8 +159,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
new ZhipuModel(
config.get(ModelTransformConfig.API_KEY),
config.get(ModelTransformConfig.MODEL),
- provider.usedEmbeddingPath(
-
config.get(ModelTransformConfig.API_PATH)),
+ apiPath,
config.get(ModelTransformConfig.DIMENSION),
config.get(
EmbeddingTransformConfig
@@ -171,7 +182,12 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
default:
throw new IllegalArgumentException("Unsupported model provider: " + provider);
}
- // Initialize dimension
+ if (isMultimodalFields && !(model instanceof MultimodalModel)) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Model provider: %s does not support multimodal embedding",
+ provider));
+ }
dimension = model.dimension();
} catch (IOException e) {
throw new RuntimeException("Failed to initialize model", e);
@@ -180,33 +196,55 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
}
}
- private void initOutputFields(SeaTunnelRowType inputRowType, Map<String,
String> fields) {
+ private void initOutputFields(SeaTunnelRowType inputRowType,
ReadonlyConfig config) {
+ Map<Integer, FieldSpec> fieldSpecMap = new HashMap<>();
List<String> fieldNames = new ArrayList<>();
- List<Integer> fieldOriginalIndexes = new ArrayList<>();
- for (Map.Entry<String, String> field : fields.entrySet()) {
- String srcField = field.getValue();
+ Map<String, Object> fieldsConfig =
+ config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS);
+ if (fieldsConfig == null || fieldsConfig.isEmpty()) {
+ throw new IllegalArgumentException("vectorization_fields configuration is required");
+ }
+
+ for (Map.Entry<String, Object> field : fieldsConfig.entrySet()) {
+ FieldSpec fieldSpec = new FieldSpec(field);
+ log.info("Field spec: {}", fieldSpec.toString());
+ String srcField = fieldSpec.getFieldName();
int srcFieldIndex;
try {
srcFieldIndex = inputRowType.indexOf(srcField);
} catch (IllegalArgumentException e) {
throw
TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField);
}
+ if (fieldSpec.isMultimodalField()) {
+ isMultimodalFields = true;
+ }
+ fieldSpecMap.put(srcFieldIndex, fieldSpec);
fieldNames.add(field.getKey());
- fieldOriginalIndexes.add(srcFieldIndex);
}
+ this.fieldSpecMap = fieldSpecMap;
this.fieldNames = fieldNames;
- this.fieldOriginalIndexes = fieldOriginalIndexes;
}
@Override
protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
tryOpen();
try {
- Object[] fieldArray = new Object[fieldOriginalIndexes.size()];
- for (int i = 0; i < fieldOriginalIndexes.size(); i++) {
- fieldArray[i] = inputRow.getField(fieldOriginalIndexes.get(i));
+ if (MetadataUtil.isBinaryFormat(inputRow)) {
+ return vectorizationBinaryRow(inputRow);
}
- List<ByteBuffer> vectorization = model.vectorization(fieldArray);
+ Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
+ Object[] fieldValues = new Object[fieldOriginalIndexes.size()];
+ List<ByteBuffer> vectorization;
+ int i = 0;
+
+ for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
+ FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
+ Object value = inputRow.getField(fieldOriginalIndex);
+ fieldValues[i++] =
+ isMultimodalFields ? new
MultimodalFieldValue(fieldSpec, value) : value;
+ }
+
+ vectorization = model.vectorization(fieldValues);
return vectorization.toArray();
} catch (Exception e) {
throw new RuntimeException("Failed to data vectorization", e);
@@ -217,6 +255,7 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
@VisibleForTesting
public Column[] getOutputColumns() {
tryOpen();
+ log.info("getOutputColumns: {}", fieldNames);
Column[] columns = new Column[fieldNames.size()];
for (int i = 0; i < fieldNames.size(); i++) {
columns[i] =
@@ -237,11 +276,107 @@ public class EmbeddingTransform extends
MultipleFieldOutputTransform {
return "Embedding";
}
+ public boolean isMultimodalFields() {
+ return isMultimodalFields;
+ }
+
+ /** Process a row in binary format: [data, relativePath, partIndex] */
+ private Object[] vectorizationBinaryRow(SeaTunnelRowAccessor inputRow)
throws Exception {
+
+ byte[] completeData = processBinaryRow(inputRow);
+ if (completeData == null) {
+ return null;
+ }
+ Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
+ Object[] fieldValues = new Object[fieldOriginalIndexes.size()];
+ int i = 0;
+
+ for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
+ FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
+ if (fieldSpec.isBinary()) {
+ fieldValues[i++] = new MultimodalFieldValue(fieldSpec,
completeData);
+ } else {
+ log.warn(
+ "Non-binary field {} configured in binary format data",
+ fieldSpec.getFieldName());
+ fieldValues[i++] = null;
+ }
+ }
+
+ try {
+ return model.vectorization(fieldValues).toArray();
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Failed to vectorize binary data for file: " +
inputRow.toString(), e);
+ }
+ }
+
+ private byte[] processBinaryRow(SeaTunnelRowAccessor inputRow) throws
Exception {
+ byte[] data = (byte[]) inputRow.getField(0);
+ String relativePath = (String) inputRow.getField(1);
+ long partIndex = (long) inputRow.getField(2);
+
+ if (partIndex != -1) {
+ checkPartOrder(relativePath, partIndex);
+ }
+ cacheBinaryChunk(relativePath, partIndex, data);
+ if (MetadataUtil.isComplete(inputRow)) {
+ byte[] completeFile = assembleCompleteFile(relativePath);
+ cleanupFileCache(relativePath);
+ log.info(
+ "Assembled complete file: {}, size: {} bytes",
+ relativePath,
+ completeFile.length);
+ return completeFile;
+ }
+ return null;
+ }
+
+ /** Validate that partIndex is in correct order for the given file */
+ private void checkPartOrder(String relativePath, long partIndex) throws Exception {
+ Long lastPartIndex = partIndexMap.getOrDefault(relativePath, -1L);
+ if (partIndex - 1 != lastPartIndex) {
+ throw new Exception("Last part index is " + lastPartIndex + ", but got " + partIndex);
+ }
+ partIndexMap.put(relativePath, partIndex);
+ }
+
+ private void cacheBinaryChunk(String relativePath, long partIndex, byte[]
data) {
+ if (partIndex >= 0) {
+ binaryFileCache
+ .computeIfAbsent(relativePath, k -> new TreeMap<>())
+ .put(partIndex, data);
+ }
+ }
+
+ private byte[] assembleCompleteFile(String relativePath) {
+ TreeMap<Long, byte[]> chunks = binaryFileCache.get(relativePath);
+ try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream())
{
+ for (Map.Entry<Long, byte[]> entry : chunks.entrySet()) {
+ byte[] chunk = entry.getValue();
+ if (chunk.length > 0) {
+ outputStream.write(chunk);
+ }
+ }
+ return outputStream.toByteArray();
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to assemble complete file: " +
relativePath, e);
+ }
+ }
+
+ private void cleanupFileCache(String relativePath) {
+ binaryFileCache.remove(relativePath);
+ partIndexMap.remove(relativePath);
+ log.info("Cleaned up cache and partIndex tracking for file: {}",
relativePath);
+ }
+
@SneakyThrows
@Override
public void close() {
if (model != null) {
model.close();
}
+ binaryFileCache.clear();
+ partIndexMap.clear();
}
}
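Note on the binary path above: the chunk handling (order check, per-file cache, assembly on the final chunk, cleanup) can be sketched as a standalone class. This is a simplified illustration under stated assumptions, not the transform itself. ChunkAssemblerSketch and its accept method are hypothetical names, the "complete" flag is passed in directly instead of being read from row metadata, and part indexes are assumed to start at 0 (the -1 special case in the diff is left out).

```java
import java.io.ByteArrayOutputStream;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;

public class ChunkAssemblerSketch {
    // Chunks per file, kept sorted by part index.
    private final Map<String, TreeMap<Long, byte[]>> cache = new ConcurrentHashMap<>();
    // Last part index seen per file, used to reject out-of-order chunks.
    private final Map<String, Long> lastIndex = new ConcurrentHashMap<>();

    /** Returns the assembled file when the final chunk arrives, otherwise null. */
    public byte[] accept(String path, long partIndex, byte[] data, boolean complete) {
        long last = lastIndex.getOrDefault(path, -1L);
        if (partIndex != last + 1) {
            throw new IllegalStateException(
                    "Last part index is " + last + ", but got " + partIndex);
        }
        lastIndex.put(path, partIndex);
        cache.computeIfAbsent(path, k -> new TreeMap<>()).put(partIndex, data);
        if (!complete) {
            return null; // caller should emit nothing for this row
        }
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // remove() both fetches the chunks and frees the cache entry.
        for (byte[] chunk : cache.remove(path).values()) {
            out.write(chunk, 0, chunk.length);
        }
        lastIndex.remove(path);
        return out.toByteArray();
    }

    public static void main(String[] args) {
        ChunkAssemblerSketch assembler = new ChunkAssemblerSketch();
        System.out.println(assembler.accept("cat.png", 0, new byte[] {1, 2}, false) == null);
        System.out.println(assembler.accept("cat.png", 1, new byte[] {3}, true).length);
    }
}
```

Returning null for intermediate chunks lines up with the null check added to MultipleFieldOutputTransform.transformRow, which drops rows for incomplete binary files so that only one output row per file reaches the sink.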
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformConfig.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformConfig.java
index a4b0c7a253..8193adc3a0 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformConfig.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransformConfig.java
@@ -17,6 +17,8 @@
package org.apache.seatunnel.transform.nlpmodel.embedding;
+import
org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
+
import org.apache.seatunnel.api.configuration.Option;
import org.apache.seatunnel.api.configuration.Options;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
@@ -32,10 +34,14 @@ public class EmbeddingTransformConfig extends
ModelTransformConfig {
.withDescription(
"The number of single vectorized inputs, default
is 1 , which means 1 inputs will be vectorized in one request , eg: qianfan
only allows a maximum of 16 simultaneous messages, depending on your own
settings, etc");
- public static final Option<Map<String, String>> VECTORIZATION_FIELDS =
+ public static final Option<Map<String, Object>> VECTORIZATION_FIELDS =
Options.key("vectorization_fields")
- .mapType()
+ .type(new TypeReference<Map<String, Object>>() {})
.noDefaultValue()
.withDescription(
- "Specify the field vectorization relationship between input and output");
+ "Specify the field vectorization relationship between input and output. "
+ + "Supports multiple formats: "
+ + "1. String format: 'fieldName' (defaults to text modality) "
+ + "2. Object format with modality and format: {field: 'fieldName', modality: 'modalityType', format: 'formatType'} "
+ + "where modality can be 'image/jpeg', 'video/mp4', etc., and format can be 'url' or 'binary'.");
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java
new file mode 100644
index 0000000000..94ee65329e
--- /dev/null
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.embedding;
+
+import
org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType;
+import
org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat;
+
+import lombok.Data;
+
+import java.io.Serializable;
+import java.util.Map;
+
+@Data
+public class FieldSpec implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ private String fieldName;
+ private ModalityType modalityType;
+ private PayloadFormat payloadFormat;
+
+ public FieldSpec(String fieldName) {
+ this.fieldName = fieldName;
+ this.modalityType = ModalityType.TEXT;
+ this.payloadFormat = PayloadFormat.TEXT;
+ }
+
+ public FieldSpec(Map.Entry<String, Object> fieldConfig) {
+ String outputFieldName = fieldConfig.getKey();
+ if (outputFieldName == null) {
+ throw new IllegalArgumentException("Output field name cannot be null");
+ }
+ Object fieldValue = fieldConfig.getValue();
+ try {
+ if (fieldValue instanceof String) {
+ parseBasicFieldSpec((String) fieldValue);
+ } else {
+ Map<String, Object> fieldSpecConfig = (Map<String, Object>)
fieldValue;
+ parseMultimodalFieldSpec(fieldSpecConfig);
+ }
+ } catch (Exception e) {
+ String errorMessage =
+ String.format(
+ "Invalid field spec for output field '%s': %s",
+ outputFieldName, fieldConfig);
+ throw new IllegalArgumentException(errorMessage, e);
+ }
+ }
+
+ /** Parse a basic field spec: just the field name, with TEXT modality and TEXT payload format. */
+ private void parseBasicFieldSpec(String fieldSpec) {
+ if (fieldSpec == null || fieldSpec.trim().isEmpty()) {
+ throw new IllegalArgumentException("Field spec cannot be null or
empty");
+ }
+ this.fieldName = fieldSpec.trim();
+ this.modalityType = ModalityType.TEXT;
+ this.payloadFormat = PayloadFormat.TEXT;
+ }
+
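+ /**
+ * Parse a multimodal field spec from a map with the keys 'field' (required), 'modality'
+ * (optional, defaults to TEXT) and 'format' (optional). Without 'modality', 'format'
+ * defaults to TEXT; with 'modality' but no 'format', the payload format is left unset.
+ */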
+ private void parseMultimodalFieldSpec(Map<String, Object> fieldConfig) {
+ if (fieldConfig == null || fieldConfig.isEmpty()) {
+ throw new IllegalArgumentException("Field configuration cannot be null or empty");
+ }
+
+ Object fieldNameObj = fieldConfig.get("field");
+ if (fieldNameObj == null) {
+ throw new IllegalArgumentException(
+ "Field name ('field') is required in field configuration");
+ }
+
+ this.fieldName = fieldNameObj.toString().trim();
+ if (this.fieldName.isEmpty()) {
+ throw new IllegalArgumentException("Field name cannot be empty");
+ }
+ Object modalityObj = fieldConfig.get("modality");
+ if (modalityObj != null) {
+ this.modalityType = ModalityType.ofName(modalityObj.toString());
+ Object formatObj = fieldConfig.get("format");
+ if (formatObj != null) {
+ this.payloadFormat = PayloadFormat.ofName(formatObj.toString());
+ }
+ } else {
+ this.modalityType = ModalityType.TEXT;
+ Object formatObj = fieldConfig.get("format");
+ if (formatObj != null) {
+ this.payloadFormat = PayloadFormat.ofName(formatObj.toString());
+ } else {
+ this.payloadFormat = PayloadFormat.TEXT;
+ }
+ }
+ }
+
+ public boolean isMultimodalField() {
+ return !ModalityType.TEXT.equals(modalityType);
+ }
+
+ public boolean isBinary() {
+ return PayloadFormat.BINARY.equals(payloadFormat);
+ }
+}
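
The two accepted shapes of a field entry (a plain string, or a map with 'field', 'modality' and 'format') can be sketched outside SeaTunnel as follows. This is only a standalone illustration of the defaulting rules in FieldSpec, not the class itself: FieldSpecShapes and describe are hypothetical names, and modality and format are kept as plain strings instead of the ModalityType and PayloadFormat enums.

```java
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of SeaTunnel: mirrors FieldSpec's defaulting rules with strings.
class FieldSpecShapes {

    /** Returns "field|modality|format" for one output-field entry. */
    static String describe(Map.Entry<String, Object> entry) {
        Object value = entry.getValue();
        if (value instanceof String) {
            // Basic shape: the value is just the source field name, treated as plain text.
            return ((String) value).trim() + "|text|text";
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> spec = (Map<String, Object>) value;
        String field = spec.get("field").toString().trim();
        Object modality = spec.get("modality");
        Object format = spec.get("format");
        if (modality == null) {
            // No modality: TEXT modality, and TEXT format unless one is configured.
            return field + "|text|" + (format == null ? "text" : format.toString().toLowerCase());
        }
        // Modality given: the format is only set when configured explicitly.
        return field
                + "|"
                + modality.toString().toLowerCase()
                + "|"
                + (format == null ? "unset" : format.toString().toLowerCase());
    }

    public static void main(String[] args) {
        Map<String, Object> image = new HashMap<>();
        image.put("field", "image_field");
        image.put("modality", "jpeg");
        image.put("format", "url");
        System.out.println(
                describe(new AbstractMap.SimpleEntry<String, Object>("text_vector", "title")));
        System.out.println(
                describe(new AbstractMap.SimpleEntry<String, Object>("image_vector", image)));
    }
}
```

The map key of each entry is the output vector column; the source column is the string value or the 'field' key.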
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/ModalityType.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/ModalityType.java
new file mode 100644
index 0000000000..0d5f337348
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/ModalityType.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.embedding.multimodal;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.ToString;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Enumeration for multimodal modality types supported by embedding models */
+@AllArgsConstructor
+@Getter
+@ToString
+public enum ModalityType {
+ TEXT("text", ModalityGroup.TEXT, Arrays.asList("text")),
+ JPEG("jpeg", ModalityGroup.IMAGE, Arrays.asList("jpg", "jpeg")),
+ PNG("png", ModalityGroup.IMAGE, Arrays.asList("png", "apng")),
+ GIF("gif", ModalityGroup.IMAGE, Arrays.asList("gif")),
+ WEBP("webp", ModalityGroup.IMAGE, Arrays.asList("webp")),
+ BMP("bmp", ModalityGroup.IMAGE, Arrays.asList("bmp", "dib")),
+ TIFF("tiff", ModalityGroup.IMAGE, Arrays.asList("tiff", "tif")),
+ ICO("ico", ModalityGroup.IMAGE, Arrays.asList("ico")),
+ ICNS("icns", ModalityGroup.IMAGE, Arrays.asList("icns")),
+ SGI("sgi", ModalityGroup.IMAGE, Arrays.asList("sgi")),
+ JPEG2000(
+ "jpeg2000",
+ ModalityGroup.IMAGE,
+ Arrays.asList("j2c", "j2k", "jp2", "jpc", "jpf", "jpx")),
+
+ MP4("mp4", ModalityGroup.VIDEO, Arrays.asList("mp4")),
+ AVI("avi", ModalityGroup.VIDEO, Arrays.asList("avi")),
+ MOV("mov", ModalityGroup.VIDEO, Arrays.asList("mov"));
+
+ private final String name;
+ private final ModalityGroup group;
+ private final List<String> fileExtensions;
+
+ public static ModalityType ofName(String name) {
+ if (name == null || name.trim().isEmpty()) {
+ return null;
+ }
+
+ String trimmedName = name.trim().toLowerCase();
+ for (ModalityType type : ModalityType.values()) {
+ if (type.name.equalsIgnoreCase(trimmedName)) {
+ return type;
+ }
+ }
+
+ throw new IllegalArgumentException("Unsupported modality type: " + name.trim());
+ }
+
+ /**
+ * Determine the ModalityType from a file extension. For values that are not in binary format
+ * (a path or URL), the text after the last dot is matched against each type's known
+ * extensions. Returns null when no type matches.
+ */
+ public static ModalityType fromFileSuffix(String value) {
+ if (value == null || value.trim().isEmpty()) {
+ return null;
+ }
+ String trimmedValue = value.trim().toLowerCase();
+ String extension = "";
+ int lastDotIndex = trimmedValue.lastIndexOf('.');
+ if (lastDotIndex > 0 && lastDotIndex < trimmedValue.length() - 1) {
+ extension = trimmedValue.substring(lastDotIndex + 1);
+ }
+ for (ModalityType type : ModalityType.values()) {
+ if (type.fileExtensions.contains(extension)) {
+ return type;
+ }
+ }
+ return null;
+ }
+
+ /** Get all supported file extensions for this modality type */
+ public List<String> getSupportedExtensions() {
+ return fileExtensions;
+ }
+
+ /** Check if this modality type supports the given file extension */
+ public boolean supportsExtension(String extension) {
+ if (extension == null) {
+ return false;
+ }
+ return fileExtensions.contains(extension.toLowerCase());
+ }
+
+ public enum ModalityGroup {
+ IMAGE,
+ VIDEO,
+ TEXT
+ }
+}
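
The suffix rule in fromFileSuffix can be tried in isolation with the sketch below. It is a standalone approximation, not the enum itself: SuffixDetection, extensionOf and looksLikeImage are hypothetical names, and the extension list is an assumed subset chosen for illustration.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical standalone sketch of suffix-based detection; not the SeaTunnel enum.
class SuffixDetection {

    // Assumed subset of the image extensions, for illustration only.
    private static final List<String> IMAGE_EXTENSIONS = Arrays.asList("jpg", "jpeg", "png");

    /** Text after the last dot, lower-cased; empty when there is no usable extension. */
    static String extensionOf(String value) {
        String trimmed = value.trim().toLowerCase();
        int lastDot = trimmed.lastIndexOf('.');
        // A leading dot (index 0) or a trailing dot does not count as an extension.
        if (lastDot > 0 && lastDot < trimmed.length() - 1) {
            return trimmed.substring(lastDot + 1);
        }
        return "";
    }

    static boolean looksLikeImage(String value) {
        return IMAGE_EXTENSIONS.contains(extensionOf(value));
    }

    public static void main(String[] args) {
        System.out.println(looksLikeImage("https://example.com/photo.JPG"));
        // The query string becomes part of the "extension", so nothing matches here.
        System.out.println(looksLikeImage("https://example.com/photo.jpg?w=200"));
    }
}
```

Under this rule a URL that carries a query string or fragment is not recognized by its suffix, so the modality configured in the field spec would be the one that applies.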
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java
new file mode 100644
index 0000000000..01c3e50403
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.embedding.multimodal;
+
+import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec;
+
+import lombok.Getter;
+import lombok.extern.slf4j.Slf4j;
+
+import java.io.Serializable;
+import java.util.Base64;
+
+@Slf4j
+@Getter
+public class MultimodalFieldValue implements Serializable {
+
+ private static final long serialVersionUID = 1L;
+
+ private final FieldSpec fieldSpec;
+ private final Object value;
+
+ public MultimodalFieldValue(FieldSpec fieldSpec, Object value) {
+ this.value = value;
+ fieldSpec.setModalityType(determineModalityType(fieldSpec, value));
+ this.fieldSpec = fieldSpec;
+ }
+
+ /**
+ * Determine the actual modality type from the field spec and the value. For binary payloads
+ * the configured modality is used as-is; otherwise the value's file suffix is analyzed, and
+ * the configured modality is the fallback when no known suffix is found.
+ */
+ private ModalityType determineModalityType(FieldSpec fieldSpec, Object value) {
+
+ if (fieldSpec.isBinary()) {
+ return fieldSpec.getModalityType();
+ }
+ if (value != null) {
+ String valueStr = value.toString();
+ ModalityType detectedType = ModalityType.fromFileSuffix(valueStr);
+ if (detectedType != null) {
+ log.debug(
+ "Auto-detected modality type '{}' from value: {}", detectedType, valueStr);
+ return detectedType;
+ }
+ }
+ return fieldSpec.getModalityType();
+ }
+
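+ public String toBase64() {
+ if (value == null) {
+ throw new IllegalArgumentException("Binary data cannot be null");
+ }
+ // Encode raw bytes directly; byte[].toString() would yield "[B@<hash>", not the payload
+ byte[] bytes =
+ value instanceof byte[]
+ ? (byte[]) value
+ : value.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
+ return Base64.getEncoder().encodeToString(bytes);
+ }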
+}
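
When a binary field arrives as a byte array, Base64 encoding has to operate on the array itself rather than on its string form. The sketch below illustrates the difference using only java.util.Base64; BinaryPayloadSketch and encode are hypothetical names, and the sample bytes are a placeholder rather than real image data.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Hypothetical standalone sketch; not part of SeaTunnel.
class BinaryPayloadSketch {

    /** Base64-encode a value, treating byte[] as raw bytes and anything else as UTF-8 text. */
    static String encode(Object value) {
        if (value == null) {
            throw new IllegalArgumentException("Binary data cannot be null");
        }
        byte[] bytes =
                value instanceof byte[]
                        ? (byte[]) value
                        : value.toString().getBytes(StandardCharsets.UTF_8);
        return Base64.getEncoder().encodeToString(bytes);
    }

    public static void main(String[] args) {
        byte[] raw = "mock-image-data".getBytes(StandardCharsets.UTF_8);
        // Decoding the encoded array restores the original bytes.
        byte[] decoded = Base64.getDecoder().decode(encode(raw));
        System.out.println(new String(decoded, StandardCharsets.UTF_8));
        // The string form of an array is an identity tag such as "[B@1b6d3586", not its content.
        System.out.println(raw.toString().startsWith("[B@"));
    }
}
```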
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalModel.java
new file mode 100644
index 0000000000..57142e0308
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalModel.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.embedding.multimodal;
+
+import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
+
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Abstract base class for multimodal embedding models that can handle text, image, and video
+ * data
+ */
+public abstract class MultimodalModel extends AbstractModel {
+
+ public MultimodalModel(Integer vectorizedNumber) {
+ super(vectorizedNumber);
+ }
+
+ @Override
+ protected final List<List<Float>> vector(Object[] fields) throws IOException {
+ if (isMultimodalFields(fields)) {
+ return multimodalVector(fields);
+ } else {
+ return textVector(fields);
+ }
+ }
+
+ protected abstract List<List<Float>> textVector(Object[] fields) throws IOException;
+
+ protected abstract List<List<Float>> multimodalVector(Object[] fields) throws IOException;
+
+ /** Check whether the fields hold multimodal data; only the first element is inspected */
+ @VisibleForTesting
+ public boolean isMultimodalFields(Object[] fields) {
+ return fields != null && fields.length > 0 && fields[0] instanceof MultimodalFieldValue;
+ }
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/PayloadFormat.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/PayloadFormat.java
new file mode 100644
index 0000000000..60adcb3f4f
--- /dev/null
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/PayloadFormat.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.nlpmodel.embedding.multimodal;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.ToString;
+
+/** Enumeration for data formats supported by multimodal embedding models */
+@AllArgsConstructor
+@Getter
+@ToString
+public enum PayloadFormat {
+ URL("url"),
+ TEXT("text"),
+ BINARY("binary");
+
+ private final String name;
+
+ public static PayloadFormat ofName(String name) {
+ if (name == null || name.trim().isEmpty()) {
+ return URL;
+ }
+ for (PayloadFormat format : PayloadFormat.values()) {
+ if (format.name.equalsIgnoreCase(name.trim())) {
+ return format;
+ }
+ }
+ String supportedFormats =
+ String.join(
+ ", ",
+ java.util.Arrays.stream(PayloadFormat.values())
+ .map(PayloadFormat::getName)
+ .toArray(String[]::new));
+
+ throw new IllegalArgumentException(
+ "Unsupported data format: "
+ + name.trim()
+ + ". Supported formats: "
+ + supportedFormats);
+ }
+}
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
index 1994d5b51c..168c87b14e 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
@@ -32,7 +32,7 @@ public abstract class AbstractModel implements Model {
protected static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
protected static final String DIMENSION_EXAMPLE = "dimension example";
- private final Integer singleVectorizedInputNumber;
+ protected final Integer singleVectorizedInputNumber;
protected AbstractModel(Integer singleVectorizedInputNumber) {
this.singleVectorizedInputNumber = singleVectorizedInputNumber;
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
index 2174e61996..46250ac829 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java
@@ -23,7 +23,10 @@ import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
-import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
+import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalFieldValue;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalModel;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
@@ -34,19 +37,38 @@ import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-public class DoubaoModel extends AbstractModel {
+public class DoubaoModel extends MultimodalModel {
private final CloseableHttpClient client;
private final String apiKey;
private final String model;
private final String apiPath;
+ private final boolean isMultimodalFields;
+
+ private static final String BASE64_PARAM_TEMPLATE = "data:%s/%s;base64,%s";
public DoubaoModel(String apiKey, String model, String apiPath, Integer vectorizedNumber) {
- this(apiKey, model, apiPath, vectorizedNumber, HttpClients.createDefault());
+ this(apiKey, model, apiPath, vectorizedNumber, false, HttpClients.createDefault());
+ }
+
+ public DoubaoModel(
+ String apiKey,
+ String model,
+ String apiPath,
+ Integer vectorizedNumber,
+ boolean isMultimodalFields) {
+ this(
+ apiKey,
+ model,
+ apiPath,
+ vectorizedNumber,
+ isMultimodalFields,
+ HttpClients.createDefault());
}
public DoubaoModel(
@@ -54,25 +76,45 @@ public class DoubaoModel extends AbstractModel {
String model,
String apiPath,
Integer vectorizedNumber,
+ boolean isMultimodalFields,
CloseableHttpClient client) {
super(vectorizedNumber);
this.apiKey = apiKey;
this.model = model;
this.apiPath = apiPath;
+ this.isMultimodalFields = isMultimodalFields;
this.client = client;
}
@Override
- protected List<List<Float>> vector(Object[] fields) throws IOException {
- return vectorGeneration(fields);
+ protected List<List<Float>> textVector(Object[] fields) throws IOException {
+ return textVectorGeneration(fields);
+ }
+
+ @Override
+ public List<List<Float>> multimodalVector(Object[] fields) throws IOException {
+ if (singleVectorizedInputNumber > 1) {
+ throw new IllegalArgumentException(
+ "Doubao does not support batch multimodal vectorization in a single request.");
+ }
+ List<List<Float>> vectors = new ArrayList<>();
+ for (Object field : fields) {
+ vectors.add(multimodalVectorGeneration((MultimodalFieldValue) field));
+ }
+ return vectors;
}
@Override
public Integer dimension() throws IOException {
- return vectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
+ return isMultimodalFields
+ ? multimodalVectorGeneration(
+ new MultimodalFieldValue(
+ new FieldSpec(DIMENSION_EXAMPLE), DIMENSION_EXAMPLE))
+ .size()
+ : textVectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size();
}
- private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
+ private List<List<Float>> textVectorGeneration(Object[] fields) throws IOException {
HttpPost post = new HttpPost(apiPath);
post.setHeader("Authorization", "Bearer " + apiKey);
post.setHeader("Content-Type", "application/json");
@@ -111,6 +153,109 @@ public class DoubaoModel extends AbstractModel {
return OBJECT_MAPPER.createObjectNode().put("model", model).set("input", arrayNode);
}
+ protected List<Float> multimodalVectorGeneration(MultimodalFieldValue field)
+ throws IOException {
+
+ HttpPost httpPost = new HttpPost(apiPath);
+ httpPost.setHeader("Authorization", "Bearer " + apiKey);
+ httpPost.setHeader("Content-Type", "application/json");
+
+ StringEntity entity =
+ new StringEntity(
+ OBJECT_MAPPER.writeValueAsString(multimodalBody(field)),
+ StandardCharsets.UTF_8);
+ httpPost.setEntity(entity);
+
+ try (CloseableHttpResponse response = client.execute(httpPost)) {
+ String responseBody =
+ EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
+
+ if (response.getStatusLine().getStatusCode() != 200) {
+ throw new IOException(
+ "HTTP error "
+ + response.getStatusLine().getStatusCode()
+ + ": "
+ + responseBody);
+ }
+
+ return parseMultimodalVectorResponse(responseBody);
+ }
+ }
+
+ @VisibleForTesting
+ public List<Float> parseMultimodalVectorResponse(String responseBody) throws IOException {
+ JsonNode responseJson = OBJECT_MAPPER.readTree(responseBody);
+ if (responseJson.has("error")) {
+ JsonNode error = responseJson.get("error");
+ String errorMessage =
+ error.has("message") ? error.get("message").asText() : "Unknown error";
+ throw new IOException("API error: " + errorMessage);
+ }
+
+ JsonNode dataNode = responseJson.get("data");
+ if (dataNode == null) {
+ throw new IOException("Invalid response format: missing 'data' field");
+ }
+
+ JsonNode embeddingArray = dataNode.get("embedding");
+ if (embeddingArray == null || !embeddingArray.isArray()) {
+ throw new IOException("Invalid response format: missing or invalid 'embedding' field");
+ }
+
+ List<Float> vector = new ArrayList<>();
+ for (JsonNode value : embeddingArray) {
+ vector.add(value.floatValue());
+ }
+ return vector;
+ }
+
+ @VisibleForTesting
+ public ObjectNode multimodalBody(MultimodalFieldValue field) {
+ ObjectNode requestNode = OBJECT_MAPPER.createObjectNode();
+ requestNode.put("model", model);
+ requestNode.put("encoding_format", "float");
+ ArrayNode inputDatas = OBJECT_MAPPER.createArrayNode();
+ inputDatas.add(inputRawData(field));
+ requestNode.set("input", inputDatas);
+ return requestNode;
+ }
+
+ protected ObjectNode inputRawData(MultimodalFieldValue field) {
+ ObjectNode rawDataNode = OBJECT_MAPPER.createObjectNode();
+ FieldSpec fieldSpec = field.getFieldSpec();
+ String fieldValue = field.getValue().toString().trim();
+ ModalityType fieldSpecModalityType = fieldSpec.getModalityType();
+ String modalityParamName = getModalityParamName(fieldSpecModalityType);
+ rawDataNode.put("type", modalityParamName);
+ if (ModalityType.TEXT == fieldSpecModalityType) {
+ rawDataNode.put(modalityParamName, fieldValue);
+ return rawDataNode;
+ }
+
+ if (fieldSpec.isBinary()) {
+ fieldValue =
+ String.format(
+ BASE64_PARAM_TEMPLATE,
+ fieldSpecModalityType.getGroup().name().toLowerCase(),
+ fieldSpecModalityType.getName(),
+ field.toBase64());
+ }
+ rawDataNode.set(modalityParamName, OBJECT_MAPPER.createObjectNode().put("url", fieldValue));
+
+ return rawDataNode;
+ }
+
+ private String getModalityParamName(ModalityType inputType) {
+ switch (inputType.getGroup()) {
+ case IMAGE:
+ return "image_url";
+ case VIDEO:
+ return "video_url";
+ default:
+ return "text";
+ }
+ }
+
@Override
public void close() throws IOException {
if (client != null) {
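
For binary inputs the request carries a data URI built from the modality group, the modality name and the Base64 payload, placed under "image_url" or "video_url". A minimal standalone sketch of that construction follows; DataUriSketch, paramName and dataUri are hypothetical names, group and type are passed as plain strings instead of ModalityType values, and the bytes are a placeholder.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

// Hypothetical standalone sketch; it does not call the Doubao API or SeaTunnel classes.
class DataUriSketch {

    /** Maps a modality group name to the request key, following the switch in the patch. */
    static String paramName(String group) {
        switch (group) {
            case "IMAGE":
                return "image_url";
            case "VIDEO":
                return "video_url";
            default:
                return "text";
        }
    }

    /** Builds "data:<group>/<type>;base64,<payload>" from raw bytes. */
    static String dataUri(String group, String type, byte[] payload) {
        return String.format(
                "data:%s/%s;base64,%s",
                group.toLowerCase(), type, Base64.getEncoder().encodeToString(payload));
    }

    public static void main(String[] args) {
        byte[] fakePng = "mock-image-data".getBytes(StandardCharsets.UTF_8);
        System.out.println(paramName("IMAGE") + " -> " + dataUri("IMAGE", "png", fakePng));
    }
}
```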
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java
new file mode 100644
index 0000000000..b5e4689e63
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
+
+import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalFieldValue;
+import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class DoubaoMultimodalModelTest {
+
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ @Test
+ void testMultimodalBodyWithText() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map.Entry<String, Object> textFieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>("text_vector", "Hello world");
+ FieldSpec fieldSpec = new FieldSpec(textFieldEntry);
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(fieldSpec, "Hello world");
+
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+
+ Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText());
+ Assertions.assertEquals("float", result.get("encoding_format").asText());
+ Assertions.assertEquals(1, result.get("input").size());
+
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("text", inputNode.get("type").asText());
+ Assertions.assertEquals("Hello world", inputNode.get("text").asText());
+ Assertions.assertFalse(inputNode.has("image_url"));
+ Assertions.assertFalse(inputNode.has("video_url"));
+
+ model.close();
+ }
+
+ /**
+ * { "model" : "doubao-embedding-vision", "encoding_format" : "float", "input" : [ { "type" :
+ * "image_url", "image_url" : { "url" :
+ * "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" } } ] }
+ */
+ @Test
+ void testMultimodalBodyWithImage() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map<String, Object> imageFieldConfig = new HashMap<>();
+ imageFieldConfig.put("field", "image_field");
+ imageFieldConfig.put("modality", "jpeg");
+ imageFieldConfig.put("format", "url");
+
+ Map.Entry<String, Object> imageFieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>("image_vector", imageFieldConfig);
+ FieldSpec fieldSpec = new FieldSpec(imageFieldEntry);
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(
+ fieldSpec,
+ "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg");
+
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+
+ // Verify the request structure
+ Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText());
+ Assertions.assertEquals("float", result.get("encoding_format").asText());
+ Assertions.assertTrue(result.get("input").isArray());
+ Assertions.assertEquals(1, result.get("input").size());
+
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("image_url", inputNode.get("type").asText());
+ Assertions.assertTrue(inputNode.has("image_url"));
+ Assertions.assertEquals(
+ "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg",
+ inputNode.get("image_url").get("url").asText());
+ Assertions.assertFalse(inputNode.has("text"));
+ Assertions.assertFalse(inputNode.has("video_url"));
+
+ model.close();
+ }
+
+ /**
+ * { "model" : "doubao-embedding-vision", "encoding_format" : "float", "input" : [ { "type" :
+ * "video_url", "video_url" : { "url" : "https://example.com/video.mp4" } } ] }
+ */
+ @Test
+ void testMultimodalBodyWithVideo() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map<String, Object> videoFieldConfig = new HashMap<>();
+ videoFieldConfig.put("field", "video_field");
+ videoFieldConfig.put("modality", "mP4");
+ videoFieldConfig.put("format", "url");
+
+ Map.Entry<String, Object> videoFieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>("video_vector", videoFieldConfig);
+ FieldSpec fieldSpec = new FieldSpec(videoFieldEntry);
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(fieldSpec, "https://example.com/video.mp4");
+
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+
+ Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText());
+ Assertions.assertEquals("float", result.get("encoding_format").asText());
+ Assertions.assertEquals(1, result.get("input").size());
+
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("video_url", inputNode.get("type").asText());
+ Assertions.assertTrue(inputNode.has("video_url"));
+ Assertions.assertEquals(
+ "https://example.com/video.mp4", inputNode.get("video_url").get("url").asText());
+ Assertions.assertFalse(inputNode.has("text"));
+ Assertions.assertFalse(inputNode.has("image_url"));
+
+ model.close();
+ }
+
+ /**
+ * { "type": "image_url", "image_url": { "url":
+ * "data:image/<IMAGE_FORMAT>;base64,<BASE64_IMAGE>" } }
+ */
+ @Test
+ void testMultimodalBodyWithBinaryImage() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision-250615",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map<String, Object> binaryImageFieldConfig = new HashMap<>();
+ binaryImageFieldConfig.put("field", "binary_image_field");
+ binaryImageFieldConfig.put("modality", "png");
+ binaryImageFieldConfig.put("format", "binary");
+
+ Map.Entry<String, Object> binaryImageFieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>(
+ "binary_image_vector", binaryImageFieldConfig);
+ FieldSpec fieldSpec = new FieldSpec(binaryImageFieldEntry);
+
+ byte[] mockImageData = "mock-image-data".getBytes();
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(fieldSpec, mockImageData);
+
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+
+ Assertions.assertEquals("doubao-embedding-vision-250615", result.get("model").asText());
+ Assertions.assertEquals("float", result.get("encoding_format").asText());
+ Assertions.assertEquals(1, result.get("input").size());
+
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("image_url", inputNode.get("type").asText());
+ Assertions.assertTrue(inputNode.has("image_url"));
+
+ model.close();
+ }
+
+ @Test
+ void testParseMultimodalVectorResponseSuccess() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ String successResponse =
+ "{\n"
+ + " \"created\": 1743575029,\n"
+ + " \"data\": {\n"
+ + " \"embedding\": [\n"
+ + " -0.123046875, -0.35546875, -0.318359375, 0.255859375, 1.5\n"
+ + " ],\n"
+ + " \"object\": \"embedding\"\n"
+ + " },\n"
+ + " \"id\": \"021743575029461acbe49a31755bec77b2f09448eb15fa9a88e47\",\n"
+ + " \"model\": \"doubao-embedding-vision-250615\",\n"
+ + " \"object\": \"list\",\n"
+ + " \"usage\": {\n"
+ + " \"prompt_tokens\": 13987,\n"
+ + " \"prompt_tokens_details\": {\n"
+ + " \"image_tokens\": 13800,\n"
+ + " \"text_tokens\": 187\n"
+ + " },\n"
+ + " \"total_tokens\": 13987\n"
+ + " }\n"
+ + "}";
+
+ List<Float> result = model.parseMultimodalVectorResponse(successResponse);
+
+ // Verify the parsed vector
+ Assertions.assertNotNull(result);
+ Assertions.assertEquals(5, result.size());
+ Assertions.assertEquals(-0.123046875f, result.get(0), 0.0001f);
+ Assertions.assertEquals(-0.35546875f, result.get(1), 0.0001f);
+ Assertions.assertEquals(-0.318359375f, result.get(2), 0.0001f);
+ Assertions.assertEquals(0.255859375f, result.get(3), 0.0001f);
+ Assertions.assertEquals(1.5f, result.get(4), 0.0001f);
+
+ model.close();
+ }
+
+ @Test
+ void testUrlAutoDetectModality() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map<String, Object> fieldConfig = new HashMap<>();
+ fieldConfig.put("field", "image_field");
+ fieldConfig.put("format", "url");
+ fieldConfig.put("modality", "png");
+ Map.Entry<String, Object> fieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>("image_vector",
fieldConfig);
+ FieldSpec fieldSpec = new FieldSpec(fieldEntry);
+
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(fieldSpec,
"https://example.com/photo.jpg");
+
+ Assertions.assertEquals(
+ ModalityType.JPEG,
multimodalFieldValue.getFieldSpec().getModalityType());
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("image_url", inputNode.get("type").asText());
+
+ Map<String, Object> fieldConfig2 = new HashMap<>();
+ fieldConfig2.put("field", "image_field");
+ fieldConfig2.put("format", "url");
+ fieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector",
fieldConfig2);
+ fieldSpec = new FieldSpec(fieldEntry);
+
+ multimodalFieldValue = new MultimodalFieldValue(fieldSpec,
"https://example.com/photo.jpg");
+
+ Assertions.assertEquals(
+ ModalityType.JPEG,
multimodalFieldValue.getFieldSpec().getModalityType());
+ result = model.multimodalBody(multimodalFieldValue);
+ inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("image_url", inputNode.get("type").asText());
+
+ model.close();
+ }
+
+ @Test
+ void testBinaryAutoDetectModality() throws IOException {
+ DoubaoModel model =
+ new DoubaoModel(
+ "test-api-key",
+ "doubao-embedding-vision",
+ "https://ark.cn-beijing.volces.com/api/v3/embeddings",
+ 1);
+
+ Map<String, Object> fieldConfig = new HashMap<>();
+ fieldConfig.put("field", "image_field");
+ fieldConfig.put("format", "binary");
+ fieldConfig.put("modality", "png");
+ Map.Entry<String, Object> fieldEntry =
+ new java.util.AbstractMap.SimpleEntry<>("image_vector",
fieldConfig);
+ FieldSpec fieldSpec = new FieldSpec(fieldEntry);
+
+ MultimodalFieldValue multimodalFieldValue =
+ new MultimodalFieldValue(fieldSpec,
"https://example.com/photo.jpg");
+
+ Assertions.assertEquals(
+ ModalityType.PNG,
multimodalFieldValue.getFieldSpec().getModalityType());
+ ObjectNode result = model.multimodalBody(multimodalFieldValue);
+ ObjectNode inputNode = (ObjectNode) result.get("input").get(0);
+ Assertions.assertEquals("image_url", inputNode.get("type").asText());
+
+ model.close();
+ }
+}
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
index efb65bb341..c806dc3d83 100644
--- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingModelDimensionTest.java
@@ -82,6 +82,7 @@ public class EmbeddingModelDimensionTest {
"modelName",
"https://api.doubao.io/v1/chat/completions",
1,
+ false,
client);
int dimension = ThreadLocalRandom.current().nextInt(1024, 2561);
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java
new file mode 100644
index 0000000000..c97372f8fe
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType;
+import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.AbstractMap;
+import java.util.HashMap;
+import java.util.Map;
+
+public class FieldSpecTest {
+
+ @Test
+ void testMapEntryConstructorWithStringValue() {
+ Map.Entry<String, Object> entry =
+ new AbstractMap.SimpleEntry<>("book_intro_vector",
"book_intro");
+ FieldSpec fieldSpec = new FieldSpec(entry);
+ Assertions.assertEquals("book_intro", fieldSpec.getFieldName());
+ Assertions.assertEquals(ModalityType.TEXT,
fieldSpec.getModalityType());
+ Assertions.assertEquals(PayloadFormat.TEXT,
fieldSpec.getPayloadFormat());
+ Assertions.assertFalse(fieldSpec.isMultimodalField());
+ Assertions.assertFalse(fieldSpec.isBinary());
+ }
+
+ @Test
+ void testMapEntryConstructorWithStringValueTrimming() {
+ Map.Entry<String, Object> entry =
+ new AbstractMap.SimpleEntry<>("book_intro_vector", " book_intro ");
+ FieldSpec fieldSpec = new FieldSpec(entry);
+ Assertions.assertEquals("book_intro", fieldSpec.getFieldName());
+ Assertions.assertEquals(ModalityType.TEXT,
fieldSpec.getModalityType());
+ Assertions.assertEquals(PayloadFormat.TEXT,
fieldSpec.getPayloadFormat());
+ }
+
+ @Test
+ void testMapEntryConstructorWithNullKey() {
+ Map.Entry<String, Object> entry = new AbstractMap.SimpleEntry<>(null,
"book_intro");
+ IllegalArgumentException exception =
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
new FieldSpec(entry));
+ Assertions.assertTrue(exception.getMessage().contains("Field spec cannot be null"));
+ }
+
+ @Test
+ void testMapEntryConstructorWithEmpty() {
+ Map.Entry<String, Object> entry = new
AbstractMap.SimpleEntry<>("book_intro_vector", null);
+ IllegalArgumentException exception =
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
new FieldSpec(entry));
+ Assertions.assertTrue(
+ exception.getMessage().contains("Invalid field spec for output field"));
+
+ Map.Entry<String, Object> entry2 = new
AbstractMap.SimpleEntry<>("book_intro_vector", "");
+ exception =
+ Assertions.assertThrows(
+ IllegalArgumentException.class, () -> new
FieldSpec(entry2));
+ Assertions.assertTrue(
+ exception.getMessage().contains("Invalid field spec for output field"));
+ }
+
+ @Test
+ void testMapEntryConstructorWithMapValue() {
+
+ Map<String, Object> fieldConfig = new HashMap<>();
+ fieldConfig.put("field", "book_image");
+ fieldConfig.put("modality", "jpeg");
+ fieldConfig.put("format", "binary");
+
+ Map.Entry<String, Object> entry = new
AbstractMap.SimpleEntry<>("book_field", fieldConfig);
+
+ FieldSpec fieldSpec = new FieldSpec(entry);
+
+ Assertions.assertEquals("book_image", fieldSpec.getFieldName());
+ Assertions.assertEquals(ModalityType.JPEG,
fieldSpec.getModalityType());
+ Assertions.assertEquals(PayloadFormat.BINARY,
fieldSpec.getPayloadFormat());
+ Assertions.assertTrue(fieldSpec.isMultimodalField());
+ Assertions.assertTrue(fieldSpec.isBinary());
+ }
+
+ @Test
+ void testMapEntryConstructorWithMapValueNoModality() {
+ Map<String, Object> fieldConfig = new HashMap<>();
+ fieldConfig.put("field", "book_intro");
+ fieldConfig.put("modality", "text");
+ fieldConfig.put("format", "text");
+
+ Map.Entry<String, Object> entry = new
AbstractMap.SimpleEntry<>("book_field", fieldConfig);
+
+ FieldSpec fieldSpec = new FieldSpec(entry);
+
+ Assertions.assertEquals("book_intro", fieldSpec.getFieldName());
+ Assertions.assertEquals(ModalityType.TEXT,
fieldSpec.getModalityType());
+ Assertions.assertEquals(PayloadFormat.TEXT,
fieldSpec.getPayloadFormat());
+ Assertions.assertFalse(fieldSpec.isMultimodalField());
+ }
+}
diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/MultimodalConfigTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/MultimodalConfigTest.java
new file mode 100644
index 0000000000..ba5eae1f71
--- /dev/null
+++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/MultimodalConfigTest.java
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.embedding;
+
+import org.apache.seatunnel.api.configuration.ReadonlyConfig;
+import org.apache.seatunnel.api.table.catalog.CatalogTable;
+import org.apache.seatunnel.api.table.catalog.Column;
+import org.apache.seatunnel.api.table.catalog.PhysicalColumn;
+import org.apache.seatunnel.api.table.catalog.TableIdentifier;
+import org.apache.seatunnel.api.table.catalog.TableSchema;
+import org.apache.seatunnel.api.table.type.BasicType;
+import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
+import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;
+import org.apache.seatunnel.transform.nlpmodel.embedding.EmbeddingTransform;
+import org.apache.seatunnel.transform.nlpmodel.embedding.EmbeddingTransformConfig;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+public class MultimodalConfigTest {
+
+ private CatalogTable createTestCatalogTable() {
+ Column[] columns = {
+ PhysicalColumn.of("text_field", BasicType.STRING_TYPE, 255L, true,
null, ""),
+ PhysicalColumn.of("image_field", BasicType.STRING_TYPE, 255L,
true, null, ""),
+ PhysicalColumn.of("video_field", BasicType.STRING_TYPE, 255L,
true, null, ""),
+ PhysicalColumn.of("mixed_field", BasicType.STRING_TYPE, 255L,
true, null, "")
+ };
+
+ TableSchema tableSchema =
TableSchema.builder().columns(Arrays.asList(columns)).build();
+ return CatalogTable.of(
+ TableIdentifier.of("test", "test", "test_table"),
+ tableSchema,
+ new HashMap<>(),
+ new ArrayList<>(),
+ "Test table for multimodal embedding");
+ }
+
+ @Test
+ void testIsMultimodalFieldsDetectionWithTextOnly() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Only text fields - should not be multimodal
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ vectorizationFields.put("text_vector", "text_field"); // Default to text type
+
+ // Explicitly text type using object format
+ Map<String, Object> textFieldConfig = new HashMap<>();
+ textFieldConfig.put("field", "mixed_field");
+ textFieldConfig.put("modality", "text");
+ vectorizationFields.put("text_vector2", textFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+
+ Assertions.assertNotNull(transform);
+ Assertions.assertFalse(transform.isMultimodalFields());
+ }
+
+ @Test
+ void testIsMultimodalFieldsDetectionWithImageField() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Include image field - should be multimodal
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ vectorizationFields.put("text_vector", "text_field");
+
+ // Image type using object format (use specific image format)
+ Map<String, Object> imageFieldConfig = new HashMap<>();
+ imageFieldConfig.put("field", "image_field");
+ imageFieldConfig.put("modality", "jpeg");
+ imageFieldConfig.put("format", "url");
+ vectorizationFields.put("image_vector", imageFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+ Assertions.assertNotNull(transform);
+ Assertions.assertTrue(transform.isMultimodalFields());
+ }
+
+ @Test
+ void testIsMultimodalFieldsDetectionWithVideoField() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Include video field - should be multimodal
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ vectorizationFields.put("text_vector", "text_field");
+
+ // Video type using object format (use specific video format)
+ Map<String, Object> videoFieldConfig = new HashMap<>();
+ videoFieldConfig.put("field", "video_field");
+ videoFieldConfig.put("modality", "mp4");
+ videoFieldConfig.put("format", "url");
+ vectorizationFields.put("video_vector", videoFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+ Assertions.assertNotNull(transform);
+ Assertions.assertTrue(transform.isMultimodalFields());
+ }
+
+ @Test
+ void testIsMultimodalFieldsDetectionWithMixedFields() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Include multiple modality types - should be multimodal
+ Map<String, Object> vectorizationFields = new HashMap<>();
+
+ // Text field using object format
+ Map<String, Object> textFieldConfig = new HashMap<>();
+ textFieldConfig.put("field", "text_field");
+ textFieldConfig.put("modality", "text");
+ vectorizationFields.put("text_vector", textFieldConfig);
+
+ // Image field using object format (use specific image format)
+ Map<String, Object> imageFieldConfig = new HashMap<>();
+ imageFieldConfig.put("field", "image_field");
+ imageFieldConfig.put("modality", "png");
+ imageFieldConfig.put("format", "url");
+ vectorizationFields.put("image_vector", imageFieldConfig);
+
+ // Video field using object format (use specific video format)
+ Map<String, Object> videoFieldConfig = new HashMap<>();
+ videoFieldConfig.put("field", "video_field");
+ videoFieldConfig.put("modality", "avi");
+ videoFieldConfig.put("format", "url");
+ vectorizationFields.put("video_vector", videoFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ // This should work since DOUBAO supports multimodal
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+ Assertions.assertNotNull(transform);
+ Assertions.assertTrue(transform.isMultimodalFields());
+ }
+
+ @Test
+ void testMultimodalModelValidationFailure() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ // Use a provider that doesn't support multimodal (e.g., OPENAI text-only models)
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.OPENAI.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"text-embedding-3-small");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.openai.com/v1/embeddings");
+
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ Map<String, Object> imageFieldConfig = new HashMap<>();
+ imageFieldConfig.put("field", "image_field");
+ imageFieldConfig.put("modality", "webp");
+ imageFieldConfig.put("format", "url");
+ vectorizationFields.put("image_vector", imageFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ // Should throw IllegalArgumentException when opening
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+ IllegalArgumentException exception =
+ Assertions.assertThrows(IllegalArgumentException.class,
transform::open);
+
+ Assertions.assertTrue(exception.getMessage().contains("does not support multimodal"));
+ }
+
+ @Test
+ void testMultimodalDetectionWithDefaultTextType() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.OPENAI.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Fields without explicit type specification default to text
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ vectorizationFields.put("text_vector1", "text_field");
+ vectorizationFields.put("text_vector2", "mixed_field");
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ // Should not be detected as multimodal since all fields default to text
+ EmbeddingTransform transform = new EmbeddingTransform(config,
catalogTable);
+ Assertions.assertNotNull(transform);
+ Assertions.assertFalse(transform.isMultimodalFields());
+ }
+
+ @Test
+ void testMultimodalDetectionWithInvalidModalityType() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ Map<String, Object> vectorizationFields = new HashMap<>();
+
+ // Invalid modality type using object format
+ Map<String, Object> invalidFieldConfig = new HashMap<>();
+ invalidFieldConfig.put("field", "text_field");
+ invalidFieldConfig.put("modality", "audio");
+ vectorizationFields.put("invalid_vector", invalidFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ // Should throw exception due to unsupported modality type
+ IllegalArgumentException exception =
+ Assertions.assertThrows(
+ IllegalArgumentException.class,
+ () -> new EmbeddingTransform(config, catalogTable));
+ Assertions.assertTrue(exception.getMessage().contains("Invalid field spec"));
+ }
+
+ @Test
+ void testMultimodalDetectionWithNonExistentField() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ Map<String, Object> vectorizationFields = new HashMap<>();
+
+ Map<String, Object> nonExistentFieldConfig = new HashMap<>();
+ nonExistentFieldConfig.put("field", "nonexistent_field");
+ nonExistentFieldConfig.put("modality", "gif");
+ vectorizationFields.put("nonexistent_vector", nonExistentFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ RuntimeException exception =
+ Assertions.assertThrows(
+ RuntimeException.class, () -> new
EmbeddingTransform(config, catalogTable));
+ Assertions.assertTrue(
+ exception
+ .getMessage()
+ .contains("'Embedding' transform not found in upstream schema"));
+ }
+
+ @Test
+ void testMultimodalDetectionCaseSensitivity() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Test case insensitive modality type parsing
+ Map<String, Object> vectorizationFields = new HashMap<>();
+
+ // Uppercase modality (use specific format)
+ Map<String, Object> imageFieldConfig1 = new HashMap<>();
+ imageFieldConfig1.put("field", "image_field");
+ imageFieldConfig1.put("modality", "JPEG");
+ vectorizationFields.put("image_vector1", imageFieldConfig1);
+
+ Map<String, Object> imageFieldConfig2 = new HashMap<>();
+ imageFieldConfig2.put("field", "image_field");
+ imageFieldConfig2.put("modality", "Png");
+ vectorizationFields.put("image_vector2", imageFieldConfig2);
+
+ Map<String, Object> videoFieldConfig = new HashMap<>();
+ videoFieldConfig.put("field", "video_field");
+ videoFieldConfig.put("modality", "MP4");
+ vectorizationFields.put("video_vector", videoFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+
+ // Should work with case insensitive modality types
+ Assertions.assertDoesNotThrow(
+ () -> {
+ EmbeddingTransform transform = new
EmbeddingTransform(config, catalogTable);
+ });
+ }
+
+ @Test
+ void testMultimodalDetectionWithWhitespace() {
+ CatalogTable catalogTable = createTestCatalogTable();
+
+ Map<String, Object> configMap = new HashMap<>();
+ configMap.put(ModelTransformConfig.MODEL_PROVIDER.key(),
ModelProvider.DOUBAO.name());
+ configMap.put(ModelTransformConfig.MODEL.key(),
"doubao-embedding-vision");
+ configMap.put(ModelTransformConfig.API_KEY.key(), "test-api-key");
+ configMap.put(ModelTransformConfig.API_PATH.key(),
"https://api.test.com/embeddings");
+
+ // Test field specifications with whitespace
+ Map<String, Object> vectorizationFields = new HashMap<>();
+ Map<String, Object> imageFieldConfig = new HashMap<>();
+ imageFieldConfig.put("field", " image_field ");
+ imageFieldConfig.put("modality", "bmp");
+ vectorizationFields.put("image_vector1", imageFieldConfig);
+
+ // Field with whitespace in modality
+ Map<String, Object> videoFieldConfig = new HashMap<>();
+ videoFieldConfig.put("field", "video_field");
+ videoFieldConfig.put("modality", " mov ");
+ vectorizationFields.put("video_vector", videoFieldConfig);
+
+ configMap.put(EmbeddingTransformConfig.VECTORIZATION_FIELDS.key(),
vectorizationFields);
+
+ ReadonlyConfig config = ReadonlyConfig.fromMap(configMap);
+ Assertions.assertDoesNotThrow(
+ () -> {
+ EmbeddingTransform transform = new
EmbeddingTransform(config, catalogTable);
+ });
+ }
+}