This is an automated email from the ASF dual-hosted git repository.
corgy pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new a40114cf7a [Feature][Transform-V2] Support vector series sql function
(#9765)
a40114cf7a is described below
commit a40114cf7a6367583b6d8a43ec13b1a11b599a7d
Author: liuwei178 <[email protected]>
AuthorDate: Fri Aug 29 09:39:23 2025 +0800
[Feature][Transform-V2] Support vector series sql function (#9765)
---
docs/en/transform-v2/sql-functions.md | 68 ++++++-
docs/zh/transform-v2/sql-functions.md | 66 +++++++
.../utils/{BufferUtils.java => VectorUtils.java} | 46 ++++-
.../{BufferUtilsTest.java => VectorUtilsTest.java} | 34 ++--
.../serialize/ElasticsearchRowSerializer.java | 6 +-
.../seatunnel/fake/utils/FakeDataRandomUtils.java | 8 +-
.../oceanbase/OceanBaseMysqlJdbcRowConverter.java | 6 +-
.../milvus/utils/sink/MilvusSinkConverter.java | 4 +-
.../milvus/utils/source/MilvusSourceConverter.java | 4 +-
.../seatunnel/qdrant/sink/QdrantBatchWriter.java | 4 +-
.../qdrant/source/QdrantSourceReader.java | 4 +-
.../e2e/connector/v2/milvus/MilvusIT.java | 4 +-
.../nlpmodel/embedding/remote/AbstractModel.java | 4 +-
.../transform/sql/zeta/ZetaSQLFunction.java | 21 +++
.../seatunnel/transform/sql/zeta/ZetaSQLType.java | 6 +
.../sql/zeta/functions/VectorFunction.java | 202 +++++++++++++++++++++
.../transform/embedding/EmbeddingVectorTest.java | 4 +-
.../transform/sql/zeta/VectorFunctionTest.java | 195 ++++++++++++++++++++
18 files changed, 643 insertions(+), 43 deletions(-)
diff --git a/docs/en/transform-v2/sql-functions.md
b/docs/en/transform-v2/sql-functions.md
index 2320a346f3..7fcaf0344c 100644
--- a/docs/en/transform-v2/sql-functions.md
+++ b/docs/en/transform-v2/sql-functions.md
@@ -1155,4 +1155,70 @@ SELECT * FROM dual
LATERAL VIEW EXPLODE ( SPLIT ( pk_id, ';' ) ) AS pk_id
LATERAL VIEW OUTER EXPLODE ( age ) AS age
LATERAL VIEW OUTER EXPLODE ( ARRAY(1,1) ) AS num
-```
\ No newline at end of file
+```
+
+## Vector Functions
+
+### VECTOR_DIMS
+
+```VECTOR_DIMS(vector) -> INT```
+
+Returns an INT value representing the number of dimensions (elements) in the vector.
+
+Example:
+
+VECTOR_DIMS(vector)
+
+### VECTOR_NORM
+
+```VECTOR_NORM(vector) -> DOUBLE```
+
+Calculates the L2 norm (Euclidean norm) of a vector, which represents the length or magnitude of the vector.
+
+Example:
+
+VECTOR_NORM(vector)
+
+### INNER_PRODUCT
+
+```INNER_PRODUCT(vector1, vector2) -> DOUBLE```
+
+Calculates the inner product (dot product) of two vectors, which is commonly used to measure similarity and projection between vectors.
+
+Example:
+
+INNER_PRODUCT(vector1, vector2)
+
+### COSINE_DISTANCE
+
+```COSINE_DISTANCE(vector1, vector2) -> DOUBLE```
+
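+Returns a DOUBLE value between 0 and 2, computed as 1 minus the cosine similarity:
+
+0: Vectors pointing in the same direction (completely similar); 1: Orthogonal vectors
+
+2: Vectors pointing in opposite directions (completely dissimilar)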
+
+Example:
+
+COSINE_DISTANCE(vector1, vector2)
+
+### L1_DISTANCE
+
+```L1_DISTANCE(vector1, vector2) -> DOUBLE```
+
+Calculates the Manhattan (L1) distance between two vectors.
+
+Example:
+
+L1_DISTANCE(vector1, vector2)
+
+### L2_DISTANCE
+
+```L2_DISTANCE(vector1, vector2) -> DOUBLE```
+
+Calculates the Euclidean (L2) distance between two vectors.
+
+Example:
+
+L2_DISTANCE(vector1, vector2)
\ No newline at end of file
diff --git a/docs/zh/transform-v2/sql-functions.md
b/docs/zh/transform-v2/sql-functions.md
index dc79ead98d..ad47beeb4a 100644
--- a/docs/zh/transform-v2/sql-functions.md
+++ b/docs/zh/transform-v2/sql-functions.md
@@ -1150,3 +1150,69 @@ SELECT * FROM dual
LATERAL VIEW OUTER EXPLODE ( age ) AS age
LATERAL VIEW OUTER EXPLODE ( ARRAY(1,1) ) AS num
```
+
+## 向量函数
+
+### VECTOR_DIMS
+
+```VECTOR_DIMS(vector) -> INT```
+
+返回一个INT值,表示向量中的维数(元素)。
+
+示例:
+
+VECTOR_DIMS(vector)
+
+### VECTOR_NORM
+
+```VECTOR_NORM(vector) -> DOUBLE```
+
+计算向量的L2范数(欧几里得范数),表示向量的长度或大小。
+
+示例:
+
+VECTOR_NORM(vector)
+
+### INNER_PRODUCT
+
+```INNER_PRODUCT(vector1, vector2) -> DOUBLE```
+
+计算两个向量的内积(点积),用于测量向量之间的相似性和投影。
+
+示例:
+
+INNER_PRODUCT(vector1, vector2)
+
+### COSINE_DISTANCE
+
+```COSINE_DISTANCE(vector1, vector2) -> DOUBLE```
+
+返回介于 0 和 2 之间的 DOUBLE 值(即 1 减去余弦相似度):
+
+0:方向相同的向量(完全相似);1:正交向量
+
+2:方向相反的向量(完全不相似)
+
+示例:
+
+COSINE_DISTANCE(vector1, vector2)
+
+### L1_DISTANCE
+
+```L1_DISTANCE(vector1, vector2) -> DOUBLE```
+
+计算两个向量之间的曼哈顿(L1)距离。
+
+示例:
+
+L1_DISTANCE(vector1, vector2)
+
+### L2_DISTANCE
+
+```L2_DISTANCE(vector1, vector2) -> DOUBLE```
+
+计算两个向量之间的欧几里得(L2)距离。
+
+示例:
+
+L2_DISTANCE(vector1, vector2)
\ No newline at end of file
diff --git
a/seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/BufferUtils.java
b/seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/VectorUtils.java
similarity index 69%
rename from
seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/BufferUtils.java
rename to
seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/VectorUtils.java
index ab27c8521f..9d960ac882 100644
---
a/seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/BufferUtils.java
+++
b/seatunnel-common/src/main/java/org/apache/seatunnel/common/utils/VectorUtils.java
@@ -35,8 +35,10 @@ package org.apache.seatunnel.common.utils;
import java.nio.Buffer;
import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Map;
-public class BufferUtils {
+public class VectorUtils {
public static ByteBuffer toByteBuffer(Short[] shortArray) {
ByteBuffer byteBuffer = ByteBuffer.allocate(shortArray.length * 2);
@@ -127,4 +129,46 @@ public class BufferUtils {
return intArray;
}
+
+    /**
+     * Expands a sparse vector into a dense {@code Float[]}.
+     *
+     * <p>Every map key is used as a position in the dense array and the matching value as the
+     * element stored there. The result has length {@code largestKey + 1}; positions that do not
+     * appear in the map are filled with {@code 0.0f}. An empty map yields an empty array.
+     *
+     * <p>All entries are validated before the dense array is allocated, so invalid input never
+     * causes an allocation.
+     *
+     * @param sparseVector map whose keys are non-negative {@code Integer} indices no greater than
+     *     1,000,000 and whose values are non-null {@link Number}s; must not be {@code null}
+     * @return the dense representation of {@code sparseVector}
+     * @throws IllegalArgumentException if a key is {@code null}, not an {@code Integer}, negative
+     *     or larger than 1,000,000, or if a value is {@code null} or not a {@link Number}
+     */
+    public static Float[] convertSparseVectorToFloatArray(Map<?, ?> sparseVector) {
+        if (sparseVector.isEmpty()) {
+            return new Float[0];
+        }
+        // Cap on the largest accepted index so that a single bogus key cannot force a huge
+        // allocation (OOM protection).
+        final int maxAllowedIndex = 1000000;
+        int maxIndex = -1;
+        for (Map.Entry<?, ?> entry : sparseVector.entrySet()) {
+            Object key = entry.getKey();
+            if (!(key instanceof Integer)) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Sparse vector key must be Integer, but got: %s",
+                                typeNameOf(key)));
+            }
+            int index = (Integer) key;
+            if (index < 0) {
+                throw new IllegalArgumentException(
+                        String.format("Sparse vector index cannot be negative: %d", index));
+            }
+            if (index > maxAllowedIndex) {
+                throw new IllegalArgumentException(
+                        String.format("Sparse vector index too large: %d", index));
+            }
+            Object value = entry.getValue();
+            // instanceof is false for null, so null values are rejected here as well; the type
+            // name is resolved null-safely to avoid an NPE while building the message.
+            if (!(value instanceof Number)) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Sparse vector value must be a Number, but got: %s",
+                                typeNameOf(value)));
+            }
+            maxIndex = Math.max(maxIndex, index);
+        }
+        Float[] denseVector = new Float[maxIndex + 1];
+        Arrays.fill(denseVector, 0.0f);
+        // Every entry was validated in the loop above, so these casts cannot fail.
+        for (Map.Entry<?, ?> entry : sparseVector.entrySet()) {
+            denseVector[(Integer) entry.getKey()] = ((Number) entry.getValue()).floatValue();
+        }
+        return denseVector;
+    }
+
+    /** Returns the runtime class name of {@code obj}, or {@code "null"} for a null reference. */
+    private static String typeNameOf(Object obj) {
+        return obj == null ? "null" : obj.getClass().getName();
+    }
}
diff --git
a/seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/BufferUtilsTest.java
b/seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/VectorUtilsTest.java
similarity index 71%
rename from
seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/BufferUtilsTest.java
rename to
seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/VectorUtilsTest.java
index 36a010ec91..99fa3d936f 100644
---
a/seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/BufferUtilsTest.java
+++
b/seatunnel-common/src/test/java/org/apache/seatunnel/common/utils/VectorUtilsTest.java
@@ -22,13 +22,13 @@ import org.junit.jupiter.api.Test;
import java.nio.ByteBuffer;
-public class BufferUtilsTest {
+public class VectorUtilsTest {
@Test
public void testToByteBufferAndToShortArray() {
Short[] shortArray = {1, 2, 3, 4, 5};
- ByteBuffer byteBuffer = BufferUtils.toByteBuffer(shortArray);
- Short[] resultArray = BufferUtils.toShortArray(byteBuffer);
+ ByteBuffer byteBuffer = VectorUtils.toByteBuffer(shortArray);
+ Short[] resultArray = VectorUtils.toShortArray(byteBuffer);
Assertions.assertArrayEquals(shortArray, resultArray, "Short array
conversion failed");
}
@@ -36,8 +36,8 @@ public class BufferUtilsTest {
@Test
public void testToByteBufferAndToFloatArray() {
Float[] floatArray = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f};
- ByteBuffer byteBuffer = BufferUtils.toByteBuffer(floatArray);
- Float[] resultArray = BufferUtils.toFloatArray(byteBuffer);
+ ByteBuffer byteBuffer = VectorUtils.toByteBuffer(floatArray);
+ Float[] resultArray = VectorUtils.toFloatArray(byteBuffer);
Assertions.assertArrayEquals(floatArray, resultArray, "Float array
conversion failed");
}
@@ -45,8 +45,8 @@ public class BufferUtilsTest {
@Test
public void testToByteBufferAndToDoubleArray() {
Double[] doubleArray = {1.1, 2.2, 3.3, 4.4, 5.5};
- ByteBuffer byteBuffer = BufferUtils.toByteBuffer(doubleArray);
- Double[] resultArray = BufferUtils.toDoubleArray(byteBuffer);
+ ByteBuffer byteBuffer = VectorUtils.toByteBuffer(doubleArray);
+ Double[] resultArray = VectorUtils.toDoubleArray(byteBuffer);
Assertions.assertArrayEquals(doubleArray, resultArray, "Double array
conversion failed");
}
@@ -54,8 +54,8 @@ public class BufferUtilsTest {
@Test
public void testToByteBufferAndToIntArray() {
Integer[] intArray = {1, 2, 3, 4, 5};
- ByteBuffer byteBuffer = BufferUtils.toByteBuffer(intArray);
- Integer[] resultArray = BufferUtils.toIntArray(byteBuffer);
+ ByteBuffer byteBuffer = VectorUtils.toByteBuffer(intArray);
+ Integer[] resultArray = VectorUtils.toIntArray(byteBuffer);
Assertions.assertArrayEquals(intArray, resultArray, "Integer array
conversion failed");
}
@@ -64,26 +64,26 @@ public class BufferUtilsTest {
public void testEmptyArrayConversion() {
// Test empty arrays
Short[] shortArray = {};
- ByteBuffer shortBuffer = BufferUtils.toByteBuffer(shortArray);
- Short[] shortResultArray = BufferUtils.toShortArray(shortBuffer);
+ ByteBuffer shortBuffer = VectorUtils.toByteBuffer(shortArray);
+ Short[] shortResultArray = VectorUtils.toShortArray(shortBuffer);
Assertions.assertArrayEquals(
shortArray, shortResultArray, "Empty Short array conversion
failed");
Float[] floatArray = {};
- ByteBuffer floatBuffer = BufferUtils.toByteBuffer(floatArray);
- Float[] floatResultArray = BufferUtils.toFloatArray(floatBuffer);
+ ByteBuffer floatBuffer = VectorUtils.toByteBuffer(floatArray);
+ Float[] floatResultArray = VectorUtils.toFloatArray(floatBuffer);
Assertions.assertArrayEquals(
floatArray, floatResultArray, "Empty Float array conversion
failed");
Double[] doubleArray = {};
- ByteBuffer doubleBuffer = BufferUtils.toByteBuffer(doubleArray);
- Double[] doubleResultArray = BufferUtils.toDoubleArray(doubleBuffer);
+ ByteBuffer doubleBuffer = VectorUtils.toByteBuffer(doubleArray);
+ Double[] doubleResultArray = VectorUtils.toDoubleArray(doubleBuffer);
Assertions.assertArrayEquals(
doubleArray, doubleResultArray, "Empty Double array conversion
failed");
Integer[] intArray = {};
- ByteBuffer intBuffer = BufferUtils.toByteBuffer(intArray);
- Integer[] intResultArray = BufferUtils.toIntArray(intBuffer);
+ ByteBuffer intBuffer = VectorUtils.toByteBuffer(intArray);
+ Integer[] intResultArray = VectorUtils.toIntArray(intBuffer);
Assertions.assertArrayEquals(
intArray, intResultArray, "Empty Integer array conversion
failed");
}
diff --git
a/seatunnel-connectors-v2/connector-elasticsearch/src/main/java/org/apache/seatunnel/connectors/seatunnel/elasticsearch/serialize/ElasticsearchRowSerializer.java
b/seatunnel-connectors-v2/connector-elasticsearch/src/main/java/org/apache/seatunnel/connectors/seatunnel/elasticsearch/serialize/ElasticsearchRowSerializer.java
index 20ba92c4f6..aa1e46d283 100644
---
a/seatunnel-connectors-v2/connector-elasticsearch/src/main/java/org/apache/seatunnel/connectors/seatunnel/elasticsearch/serialize/ElasticsearchRowSerializer.java
+++
b/seatunnel-connectors-v2/connector-elasticsearch/src/main/java/org/apache/seatunnel/connectors/seatunnel/elasticsearch/serialize/ElasticsearchRowSerializer.java
@@ -24,7 +24,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.exception.CommonError;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.elasticsearch.dto.ElasticsearchClusterInfo;
import org.apache.seatunnel.connectors.seatunnel.elasticsearch.dto.IndexInfo;
import
org.apache.seatunnel.connectors.seatunnel.elasticsearch.exception.ElasticsearchConnectorException;
@@ -218,7 +218,7 @@ public class ElasticsearchRowSerializer implements
SeaTunnelRowSerializer {
// Check if this field is configured as a vectorization field
if (vectorizationFields != null &&
vectorizationFields.contains(fieldName)) {
ByteBuffer buffer = (ByteBuffer) value;
- Float[] floats = BufferUtils.toFloatArray(buffer);
+ Float[] floats = VectorUtils.toFloatArray(buffer);
// Use the configured dimension or calculate it from the
buffer size
int dimension = vectorDimension > 0 ? vectorDimension :
buffer.remaining() / 4;
@@ -232,7 +232,7 @@ public class ElasticsearchRowSerializer implements
SeaTunnelRowSerializer {
} else {
// Default behavior for ByteBuffer fields not specified as
vectorization fields
ByteBuffer buffer = (ByteBuffer) value;
- Float[] floats = BufferUtils.toFloatArray(buffer);
+ Float[] floats = VectorUtils.toFloatArray(buffer);
int floatCount = buffer.remaining() / 4;
for (int i = 0; i < floatCount; i++) {
diff --git
a/seatunnel-connectors-v2/connector-fake/src/main/java/org/apache/seatunnel/connectors/seatunnel/fake/utils/FakeDataRandomUtils.java
b/seatunnel-connectors-v2/connector-fake/src/main/java/org/apache/seatunnel/connectors/seatunnel/fake/utils/FakeDataRandomUtils.java
index 5e9f4e809c..3603ece7fe 100644
---
a/seatunnel-connectors-v2/connector-fake/src/main/java/org/apache/seatunnel/connectors/seatunnel/fake/utils/FakeDataRandomUtils.java
+++
b/seatunnel-connectors-v2/connector-fake/src/main/java/org/apache/seatunnel/connectors/seatunnel/fake/utils/FakeDataRandomUtils.java
@@ -19,7 +19,7 @@ package org.apache.seatunnel.connectors.seatunnel.fake.utils;
import org.apache.seatunnel.api.table.catalog.Column;
import org.apache.seatunnel.api.table.type.DecimalType;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.seatunnel.connectors.seatunnel.fake.config.FakeConfig;
import org.apache.commons.collections4.CollectionUtils;
@@ -231,7 +231,7 @@ public class FakeDataRandomUtils {
RandomUtils.nextFloat(
fakeConfig.getVectorFloatMin(),
fakeConfig.getVectorFloatMax());
}
- return BufferUtils.toByteBuffer(floatVector);
+ return VectorUtils.toByteBuffer(floatVector);
}
public ByteBuffer randomFloat16Vector(Column column) {
@@ -244,7 +244,7 @@ public class FakeDataRandomUtils {
fakeConfig.getVectorFloatMin(),
fakeConfig.getVectorFloatMax());
float16Vector[i] = floatToFloat16(value);
}
- return BufferUtils.toByteBuffer(float16Vector);
+ return VectorUtils.toByteBuffer(float16Vector);
}
public ByteBuffer randomBFloat16Vector(Column column) {
@@ -257,7 +257,7 @@ public class FakeDataRandomUtils {
fakeConfig.getVectorFloatMin(),
fakeConfig.getVectorFloatMax());
bfloat16Vector[i] = floatToBFloat16(value);
}
- return BufferUtils.toByteBuffer(bfloat16Vector);
+ return VectorUtils.toByteBuffer(bfloat16Vector);
}
public Map<Integer, Float> randomSparseFloatVector(Column column) {
diff --git
a/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/internal/dialect/oceanbase/OceanBaseMysqlJdbcRowConverter.java
b/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/internal/dialect/oceanbase/OceanBaseMysqlJdbcRowConverter.java
index 4984568e18..bf49d4d0b7 100644
---
a/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/internal/dialect/oceanbase/OceanBaseMysqlJdbcRowConverter.java
+++
b/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/internal/dialect/oceanbase/OceanBaseMysqlJdbcRowConverter.java
@@ -25,7 +25,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonError;
import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorErrorCode;
import
org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
import
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.converter.AbstractJdbcRowConverter;
@@ -101,7 +101,7 @@ public class OceanBaseMysqlJdbcRowConverter extends
AbstractJdbcRowConverter {
for (int i = 0; i < stringArray.length; i++) {
arrays[i] = Float.parseFloat(stringArray[i]);
}
- fields[fieldIndex] = BufferUtils.toByteBuffer(arrays);
+ fields[fieldIndex] = VectorUtils.toByteBuffer(arrays);
}
break;
case DOUBLE:
@@ -188,7 +188,7 @@ public class OceanBaseMysqlJdbcRowConverter extends
AbstractJdbcRowConverter {
if (row.getField(fieldIndex) instanceof ByteBuffer) {
ByteBuffer byteBuffer = (ByteBuffer)
row.getField(fieldIndex);
// Convert ByteBuffer to Float[]
- Float[] floatArray =
BufferUtils.toFloatArray(byteBuffer);
+ Float[] floatArray =
VectorUtils.toFloatArray(byteBuffer);
StringBuilder vector = new StringBuilder();
vector.append("[");
for (Float aFloat : floatArray) {
diff --git
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/sink/MilvusSinkConverter.java
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/sink/MilvusSinkConverter.java
index 41c8131858..a8256ea56a 100644
---
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/sink/MilvusSinkConverter.java
+++
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/sink/MilvusSinkConverter.java
@@ -28,8 +28,8 @@ import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
-import org.apache.seatunnel.common.utils.BufferUtils;
import org.apache.seatunnel.common.utils.JsonUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
import
org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
@@ -74,7 +74,7 @@ public class MilvusSinkConverter {
return value.toString();
case FLOAT_VECTOR:
ByteBuffer floatVectorBuffer = (ByteBuffer) value;
- Float[] floats = BufferUtils.toFloatArray(floatVectorBuffer);
+ Float[] floats = VectorUtils.toFloatArray(floatVectorBuffer);
return Arrays.stream(floats).collect(Collectors.toList());
case BINARY_VECTOR:
case BFLOAT16_VECTOR:
diff --git
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/source/MilvusSourceConverter.java
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/source/MilvusSourceConverter.java
index 168b6df3ef..06063719fd 100644
---
a/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/source/MilvusSourceConverter.java
+++
b/seatunnel-connectors-v2/connector-milvus/src/main/java/org/apache/seatunnel/connectors/seatunnel/milvus/utils/source/MilvusSourceConverter.java
@@ -31,7 +31,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
import com.google.gson.Gson;
@@ -216,7 +216,7 @@ public class MilvusSourceConverter {
for (int i = 0; i < list.size(); i++) {
arrays[i] =
Float.parseFloat(list.get(i).toString());
}
- seatunnelField[fieldIndex] =
BufferUtils.toByteBuffer(arrays);
+ seatunnelField[fieldIndex] =
VectorUtils.toByteBuffer(arrays);
break;
} else {
throw new MilvusConnectorException(
diff --git
a/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/sink/QdrantBatchWriter.java
b/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/sink/QdrantBatchWriter.java
index 7ca4428c81..ed217af840 100644
---
a/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/sink/QdrantBatchWriter.java
+++
b/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/sink/QdrantBatchWriter.java
@@ -24,7 +24,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.qdrant.config.QdrantParameters;
import
org.apache.seatunnel.connectors.seatunnel.qdrant.exception.QdrantConnectorException;
@@ -181,7 +181,7 @@ public class QdrantBatchWriter {
case BFLOAT16_VECTOR:
case BINARY_VECTOR:
ByteBuffer floatVectorBuffer = (ByteBuffer) value;
- Float[] floats = BufferUtils.toFloatArray(floatVectorBuffer);
+ Float[] floats = VectorUtils.toFloatArray(floatVectorBuffer);
return
VectorFactory.vector(Arrays.stream(floats).collect(Collectors.toList()));
default:
return null;
diff --git
a/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/source/QdrantSourceReader.java
b/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/source/QdrantSourceReader.java
index 2c37163129..e62b462512 100644
---
a/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/source/QdrantSourceReader.java
+++
b/seatunnel-connectors-v2/connector-qdrant/src/main/java/org/apache/seatunnel/connectors/seatunnel/qdrant/source/QdrantSourceReader.java
@@ -27,7 +27,7 @@ import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.exception.CommonErrorCode;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import
org.apache.seatunnel.connectors.seatunnel.common.source.AbstractSingleSplitReader;
import
org.apache.seatunnel.connectors.seatunnel.common.source.SingleSplitReaderContext;
import
org.apache.seatunnel.connectors.seatunnel.qdrant.config.QdrantParameters;
@@ -164,7 +164,7 @@ public class QdrantSourceReader extends
AbstractSingleSplitReader<SeaTunnelRow>
List<Float> list = vector.getDataList();
Float[] vectorArray = new Float[list.size()];
list.toArray(vectorArray);
- fields[fieldIndex] = BufferUtils.toByteBuffer(vectorArray);
+ fields[fieldIndex] = VectorUtils.toByteBuffer(vectorArray);
break;
default:
throw new QdrantConnectorException(
diff --git
a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-milvus-e2e/src/test/java/org/apache/seatunnel/e2e/connector/v2/milvus/MilvusIT.java
b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-milvus-e2e/src/test/java/org/apache/seatunnel/e2e/connector/v2/milvus/MilvusIT.java
index af50159b43..3aed4f1455 100644
---
a/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-milvus-e2e/src/test/java/org/apache/seatunnel/e2e/connector/v2/milvus/MilvusIT.java
+++
b/seatunnel-e2e/seatunnel-connector-v2-e2e/connector-milvus-e2e/src/test/java/org/apache/seatunnel/e2e/connector/v2/milvus/MilvusIT.java
@@ -30,7 +30,7 @@ import
org.apache.seatunnel.api.table.catalog.exception.TableAlreadyExistExcepti
import org.apache.seatunnel.api.table.catalog.exception.TableNotExistException;
import org.apache.seatunnel.api.table.type.BasicType;
import org.apache.seatunnel.api.table.type.VectorType;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.catalog.MilvusCatalog;
import
org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSinkOptions;
import org.apache.seatunnel.e2e.common.TestResource;
@@ -364,7 +364,7 @@ public class MilvusIT extends TestSuiteBase implements
TestResource {
List<Float> vector = Arrays.asList((float) i, (float) i, (float)
i, (float) i);
row.add(VECTOR_FIELD, gson.toJsonTree(vector));
Short[] shorts = {(short) i, (short) i, (short) i, (short) i};
- ByteBuffer shortByteBuffer = BufferUtils.toByteBuffer(shorts);
+ ByteBuffer shortByteBuffer = VectorUtils.toByteBuffer(shorts);
row.add(VECTOR_FIELD2, gson.toJsonTree(shortByteBuffer.array()));
ByteBuffer binaryByteBuffer = ByteBuffer.wrap(new byte[] {16});
row.add(VECTOR_FIELD3, gson.toJsonTree(binaryByteBuffer.array()));
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
index 0803dfd7ad..1994d5b51c 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/AbstractModel.java
@@ -19,7 +19,7 @@ package
org.apache.seatunnel.transform.nlpmodel.embedding.remote;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.commons.lang3.ArrayUtils;
@@ -44,7 +44,7 @@ public abstract class AbstractModel implements Model {
List<List<Float>> vectors = batchProcess(fields,
singleVectorizedInputNumber);
for (List<Float> vector : vectors) {
- result.add(BufferUtils.toByteBuffer(vector.toArray(new Float[0])));
+ result.add(VectorUtils.toByteBuffer(vector.toArray(new Float[0])));
}
return result;
}
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
index 7967c50450..fef526d0d7 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLFunction.java
@@ -33,6 +33,7 @@ import
org.apache.seatunnel.transform.sql.zeta.functions.DateTimeFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.NumericFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.StringFunction;
import org.apache.seatunnel.transform.sql.zeta.functions.SystemFunction;
+import org.apache.seatunnel.transform.sql.zeta.functions.VectorFunction;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -203,6 +204,14 @@ public class ZetaSQLFunction {
public static final String TRY_CAST = "TRY_CAST";
+ // -------------------------vector functions----------------------------
+ public static final String COSINE_DISTANCE = "COSINE_DISTANCE";
+ public static final String L1_DISTANCE = "L1_DISTANCE";
+ public static final String L2_DISTANCE = "L2_DISTANCE";
+ public static final String VECTOR_DIMS = "VECTOR_DIMS";
+ public static final String VECTOR_NORM = "VECTOR_NORM";
+ public static final String INNER_PRODUCT = "INNER_PRODUCT";
+
private final SeaTunnelRowType inputRowType;
private final ZetaSQLType zetaSQLType;
@@ -598,6 +607,18 @@ public class ZetaSQLFunction {
return ArrayFunction.arrayMin(args);
case UUID:
return randomUUID().toString();
+ case COSINE_DISTANCE:
+ return VectorFunction.cosineDistance(args);
+ case L1_DISTANCE:
+ return VectorFunction.l1Distance(args);
+ case L2_DISTANCE:
+ return VectorFunction.l2Distance(args);
+ case VECTOR_DIMS:
+ return VectorFunction.vectorDims(args);
+ case VECTOR_NORM:
+ return VectorFunction.vectorNorm(args);
+ case INNER_PRODUCT:
+ return VectorFunction.innerProduct(args);
default:
for (ZetaUDF udf : udfList) {
if (udf.functionName().equalsIgnoreCase(functionName)) {
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
index 941390dbcc..067fab2481 100644
---
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/ZetaSQLType.java
@@ -368,6 +368,7 @@ public class ZetaSQLType {
case ZetaSQLFunction.WEEK:
case ZetaSQLFunction.YEAR:
case ZetaSQLFunction.SIGN:
+ case ZetaSQLFunction.VECTOR_DIMS:
return BasicType.INT_TYPE;
case ZetaSQLFunction.BIT_LENGTH:
case ZetaSQLFunction.CHAR_LENGTH:
@@ -402,6 +403,11 @@ public class ZetaSQLType {
case ZetaSQLFunction.RANDOM:
case ZetaSQLFunction.TRUNC:
case ZetaSQLFunction.TRUNCATE:
+ case ZetaSQLFunction.COSINE_DISTANCE:
+ case ZetaSQLFunction.L1_DISTANCE:
+ case ZetaSQLFunction.L2_DISTANCE:
+ case ZetaSQLFunction.VECTOR_NORM:
+ case ZetaSQLFunction.INNER_PRODUCT:
return BasicType.DOUBLE_TYPE;
case ZetaSQLFunction.ARRAY:
return ArrayFunction.castArrayTypeMapping(function,
inputRowType);
diff --git
a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
new file mode 100644
index 0000000000..7b37acdfbd
--- /dev/null
+++
b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/sql/zeta/functions/VectorFunction.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.sql.zeta.functions;
+
+import org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
+import org.apache.seatunnel.common.utils.VectorUtils;
+import org.apache.seatunnel.transform.exception.TransformException;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+public class VectorFunction {
+
+ public static Object cosineDistance(List<Object> args) {
+ if (args.size() != 2) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "COSINE_DISTANCE() requires 2 arguments, but %d
were provided",
+ args.size()));
+ }
+ Object arg1 = args.get(0);
+ Object arg2 = args.get(1);
+ if (arg1 == null || arg2 == null) {
+ return null;
+ }
+ Float[] vector1 = convertToFloatArray(arg1);
+ Float[] vector2 = convertToFloatArray(arg2);
+ if (vector1.length != vector2.length) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.ILLEGAL_ARGUMENT,
+ String.format(
+ "Vectors must have the same dimension: %d vs %d",
+ vector1.length, vector2.length));
+ }
+ double dotProduct =
+ IntStream.range(0, vector1.length).mapToDouble(i -> vector1[i]
* vector2[i]).sum();
+ double norm1 = Arrays.stream(vector1).mapToDouble(v -> v * v).sum();
+ double norm2 = Arrays.stream(vector2).mapToDouble(v -> v * v).sum();
+ if (norm1 == 0.0 || norm2 == 0.0) {
+ return 1.0;
+ }
+ // calculate cosine similarity
+ double cosineSimilarity = dotProduct / (Math.sqrt(norm1) *
Math.sqrt(norm2));
+ return 1.0 - cosineSimilarity;
+ }
+
+ public static Object l1Distance(List<Object> args) {
+ if (args.size() != 2) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "L1_DISTANCE() requires exactly 2 arguments, but
%d were provided",
+ args.size()));
+ }
+ Object arg1 = args.get(0);
+ Object arg2 = args.get(1);
+ if (arg1 == null || arg2 == null) {
+ return null;
+ }
+ Float[] v1 = convertToFloatArray(arg1);
+ Float[] v2 = convertToFloatArray(arg2);
+ if (v1.length != v2.length) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.ILLEGAL_ARGUMENT,
+ String.format(
+ "Vectors must have the same dimension: %d vs %d",
+ v1.length, v2.length));
+ }
+ return IntStream.range(0, v1.length).mapToDouble(i -> Math.abs(v1[i] -
v2[i])).sum();
+ }
+
+ public static Object l2Distance(List<Object> args) {
+ if (args.size() != 2) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "L2_DISTANCE() requires exactly 2 arguments, but
%d were provided",
+ args.size()));
+ }
+ Object arg1 = args.get(0);
+ Object arg2 = args.get(1);
+ if (arg1 == null || arg2 == null) {
+ return null;
+ }
+ Float[] v1 = convertToFloatArray(arg1);
+ Float[] v2 = convertToFloatArray(arg2);
+ if (v1.length != v2.length) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.ILLEGAL_ARGUMENT,
+ String.format(
+ "Vectors must have the same dimension: %d vs %d",
+ v1.length, v2.length));
+ }
+ double sum =
+ IntStream.range(0, v1.length)
+ .mapToDouble(
+ i -> {
+ double diff = v1[i] - v2[i];
+ return diff * diff;
+ })
+ .sum();
+ return Math.sqrt(sum);
+ }
+
+ public static Object vectorDims(List<Object> args) {
+ if (args.size() != 1) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "VECTOR_DIMS() requires exactly 1 argument, but %d
were provided",
+ args.size()));
+ }
+ Object arg = args.get(0);
+ if (arg == null) {
+ return null;
+ }
+ Float[] vector = convertToFloatArray(arg);
+ return vector.length;
+ }
+
+ public static Object vectorNorm(List<Object> args) {
+ if (args.size() != 1) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "VECTOR_NORM() requires exactly 1 argument, but %d
were provided",
+ args.size()));
+ }
+ Object arg = args.get(0);
+ if (arg == null) {
+ return null;
+ }
+ Float[] vector = convertToFloatArray(arg);
+ return Math.sqrt(Arrays.stream(vector).mapToDouble(v -> v * v).sum());
+ }
+
+ public static Object innerProduct(List<Object> args) {
+ if (args.size() != 2) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format(
+ "INNER_PRODUCT() requires exactly 2 arguments, but
%d were provided",
+ args.size()));
+ }
+ Object arg1 = args.get(0);
+ Object arg2 = args.get(1);
+ if (arg1 == null || arg2 == null) {
+ return null;
+ }
+ Float[] v1 = convertToFloatArray(arg1);
+ Float[] v2 = convertToFloatArray(arg2);
+ if (v1.length != v2.length) {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.ILLEGAL_ARGUMENT,
+ String.format(
+ "Vectors must have the same dimension: %d vs %d",
+ v1.length, v2.length));
+ }
+
+ return IntStream.range(0, v1.length).mapToDouble(i -> v1[i] *
v2[i]).sum();
+ }
+
+ private static Float[] convertToFloatArray(Object obj) {
+ if (obj instanceof ByteBuffer) {
+ return VectorUtils.toFloatArray((ByteBuffer) obj);
+ } else if (obj instanceof Float[]) {
+ return (Float[]) obj;
+ } else if (obj instanceof float[]) {
+ float[] primitiveArray = (float[]) obj;
+ Float[] wrapperArray = new Float[primitiveArray.length];
+ for (int i = 0; i < primitiveArray.length; i++) {
+ wrapperArray[i] = primitiveArray[i];
+ }
+ return wrapperArray;
+ } else if (obj instanceof Map) {
+ return VectorUtils.convertSparseVectorToFloatArray((Map<?, ?>)
obj);
+ } else {
+ throw new TransformException(
+ CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
+ String.format("Unsupported vector type: %s",
obj.getClass().getName()));
+ }
+ }
+}
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
index 0813de6588..9c496965bd 100644
---
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/EmbeddingVectorTest.java
@@ -22,7 +22,7 @@ import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import
org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
-import org.apache.seatunnel.common.utils.BufferUtils;
+import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;
import org.junit.jupiter.api.Assertions;
@@ -113,7 +113,7 @@ public class EmbeddingVectorTest {
Object[] inputFields = {"test input"};
List<ByteBuffer> result = model.vectorization(inputFields);
ByteBuffer buffer = result.get(0);
- Float[] embedding = BufferUtils.toFloatArray(buffer);
+ Float[] embedding = VectorUtils.toFloatArray(buffer);
Assertions.assertEquals(4, embedding.length);
Assertions.assertEquals(-0.0069292835f, embedding[0]);
Assertions.assertEquals(-0.005336422f, embedding[1]);
diff --git
a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/zeta/VectorFunctionTest.java
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/zeta/VectorFunctionTest.java
new file mode 100644
index 0000000000..b27aaea163
--- /dev/null
+++
b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/sql/zeta/VectorFunctionTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.transform.sql.zeta;
+
import org.apache.seatunnel.shade.com.google.common.collect.Maps;

import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.VectorType;
import org.apache.seatunnel.common.utils.VectorUtils;
import org.apache.seatunnel.transform.exception.TransformException;
import org.apache.seatunnel.transform.sql.SQLEngine;
import org.apache.seatunnel.transform.sql.SQLEngineFactory;
import org.apache.seatunnel.transform.sql.zeta.functions.VectorFunction;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.HashMap;
+
+public class VectorFunctionTest {
+
+ @Test
+ public void testCosineDistanceFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1", "vector_float2"},
+ new SeaTunnelDataType[] {
+ VectorType.VECTOR_FLOAT_TYPE,
VectorType.VECTOR_SPARSE_FLOAT_TYPE
+ });
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
3.0f}),
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
3.0f})
+ });
+
+ sqlEngine.init(
+ "test",
+ null,
+ rowType,
+ "select COSINE_DISTANCE(vector_float1, vector_float2) as
cosineDistance from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(0.0, f1Object);
+ }
+
+ @Test
+ public void testL1DistanceFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1", "vector_float2"},
+ new SeaTunnelDataType[] {
+ VectorType.VECTOR_FLOAT_TYPE,
VectorType.VECTOR_FLOAT_TYPE
+ });
+ HashMap<Integer, Float> sparseVector = Maps.newHashMap();
+ sparseVector.put(0, 1.0f);
+ sparseVector.put(1, 2.0f);
+ sparseVector.put(2, 3.0f);
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {2.0f, 4.0f,
6.0f}), sparseVector
+ });
+
+ sqlEngine.init(
+ "test",
+ null,
+ rowType,
+ "select L1_DISTANCE(vector_float1, vector_float2) as
l1Distance from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(6.0, f1Object);
+ }
+
+ @Test
+ public void testL2DistanceFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1", "vector_float2"},
+ new SeaTunnelDataType[] {
+ VectorType.VECTOR_FLOAT_TYPE,
VectorType.VECTOR_FLOAT_TYPE
+ });
+
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {2.0f, 4.0f,
4.0f}),
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
2.0f})
+ });
+
+ sqlEngine.init(
+ "test",
+ null,
+ rowType,
+ "select L2_DISTANCE(vector_float1, vector_float2) as
l2Distance from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(3.0, f1Object);
+ }
+
+ @Test
+ public void testVectorNormFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1", "vector_float2"},
+ new SeaTunnelDataType[] {
+ VectorType.VECTOR_FLOAT_TYPE,
VectorType.VECTOR_FLOAT_TYPE
+ });
+
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
2.0f}),
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
3.0f})
+ });
+
+ sqlEngine.init(
+ "test", null, rowType, "select VECTOR_NORM(vector_float1) as
norm from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(3.0, f1Object);
+ }
+
+ @Test
+ public void testVectorDimsFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1"},
+ new SeaTunnelDataType[]
{VectorType.VECTOR_FLOAT_TYPE});
+
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
3.0f}),
+ });
+
+ sqlEngine.init("test", null, rowType, "select
VECTOR_DIMS(vector_float1) as dim from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(3, f1Object);
+ }
+
+ @Test
+ public void testInnerProductFunction() {
+
+ SQLEngine sqlEngine =
SQLEngineFactory.getSQLEngine(SQLEngineFactory.EngineType.ZETA);
+ SeaTunnelRowType rowType =
+ new SeaTunnelRowType(
+ new String[] {"vector_float1", "vector_float2"},
+ new SeaTunnelDataType[] {
+ VectorType.VECTOR_FLOAT_TYPE,
VectorType.VECTOR_FLOAT_TYPE
+ });
+
+ SeaTunnelRow inputRow =
+ new SeaTunnelRow(
+ new Object[] {
+ VectorUtils.toByteBuffer(new Float[] {1.0f, 2.0f,
3.0f}),
+ VectorUtils.toByteBuffer(new Float[] {7.0f, 8.0f,
9.0f})
+ });
+
+ sqlEngine.init(
+ "test",
+ null,
+ rowType,
+ "select INNER_PRODUCT(vector_float1, vector_float2) as
innerProduct from test");
+ SeaTunnelRow outRow = sqlEngine.transformBySQL(inputRow,
rowType).get(0);
+ Object f1Object = outRow.getField(0);
+ Assertions.assertEquals(50.0, f1Object);
+ }
+}