This is an automated email from the ASF dual-hosted git repository.
gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 287a739510 [javaudf](string) Fix string format in java udf (#13854)
287a739510 is described below
commit 287a739510dce7cb3204a94c1b48eb8f6baaebc2
Author: Gabriel <[email protected]>
AuthorDate: Tue Nov 1 21:25:12 2022 +0800
[javaudf](string) Fix string format in java udf (#13854)
---
.../aggregate_function_java_udaf.h | 4 ++
be/src/vec/functions/function_java_udf.cpp | 20 +++---
.../java/org/apache/doris/udf/UdafExecutor.java | 13 +---
.../java/org/apache/doris/udf/UdfExecutor.java | 21 ++-----
.../main/java/org/apache/doris/udf/UdfUtils.java | 1 -
.../java/org/apache/doris/udf/UdfExecutorTest.java | 15 ++---
.../data/javaudf_p0/test_javaudf_string.out | 37 +++++++++++
.../main/java/org/apache/doris/udf/StringTest.java | 27 ++++++++
.../suites/javaudf_p0/test_javaudf_string.groovy | 71 ++++++++++++++++++++++
9 files changed, 163 insertions(+), 46 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index 4c86d3b996..f1f01edce9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -107,6 +107,10 @@ public:
RETURN_IF_ERROR(jni_frame.push(env));
RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params,
&ctor_params_bytes));
executor_obj = env->NewObject(executor_cl, executor_ctor_id,
ctor_params_bytes);
+
+ jbyte* pBytes = env->GetByteArrayElements(ctor_params_bytes,
nullptr);
+ env->ReleaseByteArrayElements(ctor_params_bytes, pBytes,
JNI_ABORT);
+ env->DeleteLocalRef(ctor_params_bytes);
}
RETURN_ERROR_IF_EXC(env);
RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, executor_obj,
&executor_obj));
diff --git a/be/src/vec/functions/function_java_udf.cpp
b/be/src/vec/functions/function_java_udf.cpp
index ddb5b2cf17..467e9e3c66 100644
--- a/be/src/vec/functions/function_java_udf.cpp
+++ b/be/src/vec/functions/function_java_udf.cpp
@@ -93,6 +93,10 @@ Status JavaFunctionCall::prepare(FunctionContext* context,
RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params,
&ctor_params_bytes));
jni_ctx->executor = env->NewObject(executor_cl_, executor_ctor_id_,
ctor_params_bytes);
+
+ jbyte* pBytes = env->GetByteArrayElements(ctor_params_bytes, nullptr);
+ env->ReleaseByteArrayElements(ctor_params_bytes, pBytes, JNI_ABORT);
+ env->DeleteLocalRef(ctor_params_bytes);
}
RETURN_ERROR_IF_EXC(env);
RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, jni_ctx->executor,
&jni_ctx->executor));
@@ -108,17 +112,17 @@ Status JavaFunctionCall::execute(FunctionContext*
context, Block& block,
JniContext* jni_ctx = reinterpret_cast<JniContext*>(
context->get_function_state(FunctionContext::THREAD_LOCAL));
int arg_idx = 0;
+ ColumnPtr args[arguments.size()];
for (size_t col_idx : arguments) {
ColumnWithTypeAndName& column = block.get_by_position(col_idx);
- auto col = column.column->convert_to_full_column_if_const();
+ args[arg_idx] = column.column->convert_to_full_column_if_const();
if (!_argument_types[arg_idx]->equals(*column.type)) {
return Status::InvalidArgument(strings::Substitute(
"$0-th input column's type $1 does not equal to required
type $2", arg_idx,
column.type->get_name(),
_argument_types[arg_idx]->get_name()));
}
- auto data_col = col;
- if (auto* nullable = check_and_get_column<const ColumnNullable>(*col))
{
- data_col = nullable->get_nested_column_ptr();
+ if (auto* nullable = check_and_get_column<const
ColumnNullable>(*args[arg_idx])) {
+ args[arg_idx] = nullable->get_nested_column_ptr();
auto null_col =
check_and_get_column<ColumnVector<UInt8>>(nullable->get_null_map_column_ptr());
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] =
@@ -127,15 +131,15 @@ Status JavaFunctionCall::execute(FunctionContext*
context, Block& block,
jni_ctx->input_nulls_buffer_ptr.get()[arg_idx] = -1;
}
- if (data_col->is_column_string()) {
- const ColumnString* str_col = assert_cast<const
ColumnString*>(data_col.get());
+ if (args[arg_idx]->is_column_string()) {
+ const ColumnString* str_col = assert_cast<const
ColumnString*>(args[arg_idx].get());
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_chars().data());
jni_ctx->input_offsets_ptrs.get()[arg_idx] =
reinterpret_cast<int64_t>(str_col->get_offsets().data());
- } else if (data_col->is_numeric() || data_col->is_column_decimal()) {
+ } else if (args[arg_idx]->is_numeric() ||
args[arg_idx]->is_column_decimal()) {
jni_ctx->input_values_buffer_ptr.get()[arg_idx] =
- reinterpret_cast<int64_t>(data_col->get_raw_data().data);
+
reinterpret_cast<int64_t>(args[arg_idx]->get_raw_data().data);
} else {
return Status::InvalidArgument(
strings::Substitute("Java UDF doesn't support type $0 now
!",
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index f944cfe2d1..b0ec36882d 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -231,14 +231,7 @@ public class UdafExecutor {
private boolean storeUdfResult(Object obj, long row) throws
UdfRuntimeException {
if (obj == null) {
- //if result is null, because we have insert default before, so
return true directly when row == 0
- //others because we hava resize the buffer, so maybe be insert
value is not correct
- if (row != 0) {
- long offset = Integer.toUnsignedLong(
- UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row));
- UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + offset - 1,
- UdfUtils.END_OF_STRING);
- }
+ // If the result is null, return true directly: the default value has
already been inserted, so nothing needs to be written for this row.
return true;
}
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
@@ -343,12 +336,10 @@ public class UdafExecutor {
return false;
}
offset += bytes.length;
- UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr) + offset - 1,
- UdfUtils.END_OF_STRING);
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * row,
Integer.parseUnsignedInt(String.valueOf(offset)));
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
offset - bytes.length - 1, bytes.length);
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
offset - bytes.length, bytes.length);
return true;
default:
throw new UdfRuntimeException("Unsupported return type: " +
retType);
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
index 2eed3f0221..fe35aec515 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
@@ -304,13 +304,6 @@ public class UdfExecutor {
assert (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1);
UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null,
outputNullPtr) + row, (byte) 1);
if (retType.equals(JavaUdfDataType.STRING)) {
- long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr);
- if (outputOffset + 1 > bufferSize) {
- return false;
- }
- outputOffset += 1;
- UdfUtils.UNSAFE.putChar(null, UdfUtils.UNSAFE.getLong(null,
outputBufferPtr)
- + outputOffset - 1, UdfUtils.END_OF_STRING);
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr)
+ 4L * row,
Integer.parseUnsignedInt(String.valueOf(outputOffset)));
}
@@ -412,16 +405,14 @@ public class UdfExecutor {
case STRING: {
long bufferSize = UdfUtils.UNSAFE.getLong(null,
outputIntermediateStatePtr);
byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- if (outputOffset + bytes.length + 1 > bufferSize) {
+ if (outputOffset + bytes.length > bufferSize) {
return false;
}
- outputOffset += (bytes.length + 1);
- UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null,
outputBufferPtr)
- + outputOffset - 1, UdfUtils.END_OF_STRING);
+ outputOffset += bytes.length;
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null,
outputOffsetsPtr) + 4L * row,
Integer.parseUnsignedInt(String.valueOf(outputOffset)));
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
outputOffset - bytes.length - 1, bytes.length);
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr) +
outputOffset - bytes.length, bytes.length);
return true;
}
default:
@@ -501,13 +492,13 @@ public class UdfExecutor {
case STRING: {
long offset =
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
+ 4L * row));
- long numBytes = row == 0 ? offset - 1 : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
+ long numBytes = row == 0 ? offset : offset -
Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
UdfUtils.UNSAFE.getLong(null,
-
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))) - 1;
+
UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
long base =
row == 0 ? UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
UdfUtils.UNSAFE.getLong(null,
UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + offset - numBytes - 1;
+ + offset - numBytes;
byte[] bytes = new byte[(int) numBytes];
UdfUtils.copyMemory(null, base, bytes,
UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
inputObjects[i] = new String(bytes,
StandardCharsets.UTF_8);
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
index cb402b4802..dbca62a64e 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
@@ -44,7 +44,6 @@ public class UdfUtils {
public static final Unsafe UNSAFE;
private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
public static final long BYTE_ARRAY_OFFSET;
- public static final char END_OF_STRING = '\0';
static {
UNSAFE = (Unsafe) AccessController.doPrivileged(
diff --git
a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
index f48df0d0c4..7330c9b083 100644
--- a/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
+++ b/fe/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java
@@ -391,10 +391,10 @@ public class UdfExecutorTest {
for (int i = 0; i < batchSize; i++) {
input1[i] = "Input1_" + i;
input2[i] = "Input2_" + i;
- inputOffsets1[i] = i == 0 ?
input1[i].getBytes(StandardCharsets.UTF_8).length + 1
- : inputOffsets1[i - 1] +
input1[i].getBytes(StandardCharsets.UTF_8).length + 1;
- inputOffsets2[i] = i == 0 ?
input2[i].getBytes(StandardCharsets.UTF_8).length + 1
- : inputOffsets2[i - 1] +
input2[i].getBytes(StandardCharsets.UTF_8).length + 1;
+ inputOffsets1[i] = i == 0 ?
input1[i].getBytes(StandardCharsets.UTF_8).length
+ : inputOffsets1[i - 1] +
input1[i].getBytes(StandardCharsets.UTF_8).length;
+ inputOffsets2[i] = i == 0 ?
input2[i].getBytes(StandardCharsets.UTF_8).length
+ : inputOffsets2[i - 1] +
input2[i].getBytes(StandardCharsets.UTF_8).length;
inputBufferSize1 +=
input1[i].getBytes(StandardCharsets.UTF_8).length;
inputBufferSize2 +=
input2[i].getBytes(StandardCharsets.UTF_8).length;
}
@@ -453,11 +453,6 @@ public class UdfExecutorTest {
Integer.parseUnsignedInt(String.valueOf(inputOffsets1[i])));
UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i,
Integer.parseUnsignedInt(String.valueOf(inputOffsets2[i])));
- UdfUtils.UNSAFE.putChar(null, inputBuffer1 + inputOffsets1[i] - 1,
- UdfUtils.END_OF_STRING);
- UdfUtils.UNSAFE.putChar(null, inputBuffer2 + inputOffsets2[i] - 1,
- UdfUtils.END_OF_STRING);
-
}
params.setInputBufferPtrs(inputBufferPtr);
params.setInputNullsPtrs(inputNullPtr);
@@ -483,9 +478,7 @@ public class UdfExecutorTest {
UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes,
UdfUtils.BYTE_ARRAY_OFFSET,
bytes.length);
}
- long curOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 *
i);
assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i]
+ input2[i]));
- assert (UdfUtils.UNSAFE.getByte(null, outputBuffer + curOffset -
1) == UdfUtils.END_OF_STRING);
assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0);
}
}
diff --git a/regression-test/data/javaudf_p0/test_javaudf_string.out
b/regression-test/data/javaudf_p0/test_javaudf_string.out
new file mode 100644
index 0000000000..60c8689fca
--- /dev/null
+++ b/regression-test/data/javaudf_p0/test_javaudf_string.out
@@ -0,0 +1,37 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select_default --
+1 abcdefg1
+2 abcdefg2
+3 abcdefg3
+4 abcdefg4
+5 abcdefg5
+6 abcdefg6
+7 abcdefg7
+8 abcdefg8
+9 abcdefg9
+10 abcdefg10
+
+-- !select --
+ab****g10
+ab***fg1
+ab***fg2
+ab***fg3
+ab***fg4
+ab***fg5
+ab***fg6
+ab***fg7
+ab***fg8
+ab***fg9
+
+-- !select --
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+ab*def ab**efg
+
diff --git
a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/StringTest.java
b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/StringTest.java
new file mode 100644
index 0000000000..cc1a6a2bca
--- /dev/null
+++
b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/StringTest.java
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.udf;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class StringTest extends UDF {
+ public String evaluate(String field, Integer a, Integer b) {
+ return field.substring(0, a) + StringUtils.repeat("*", field.length()
- a -b) + field.substring(field.length()-b);
+ }
+}
diff --git a/regression-test/suites/javaudf_p0/test_javaudf_string.groovy
b/regression-test/suites/javaudf_p0/test_javaudf_string.groovy
new file mode 100644
index 0000000000..32977e6b85
--- /dev/null
+++ b/regression-test/suites/javaudf_p0/test_javaudf_string.groovy
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import org.codehaus.groovy.runtime.IOGroovyMethods
+
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+import java.nio.file.Paths
+
+suite("test_javaudf_string") {
+ def tableName = "test_javaudf_string"
+ def jarPath =
"""${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar"""
+
+ log.info("Jar path: ${jarPath}".toString())
+ try {
+ sql """ DROP TABLE IF EXISTS ${tableName} """
+ sql """
+ CREATE TABLE IF NOT EXISTS ${tableName} (
+ `user_id` INT NOT NULL COMMENT "用户id",
+ `string_col` VARCHAR(10) NOT NULL COMMENT "用户id"
+ )
+ DISTRIBUTED BY HASH(user_id) PROPERTIES("replication_num" = "1");
+ """
+ StringBuilder sb = new StringBuilder()
+ int i = 1
+ for (; i < 10; i ++) {
+ sb.append("""
+ (${i}, 'abcdefg${i}'),
+ """)
+ }
+ sb.append("""
+ (${i}, 'abcdefg${i}')
+ """)
+ sql """ INSERT INTO ${tableName} VALUES
+ ${sb.toString()}
+ """
+ qt_select_default """ SELECT * FROM ${tableName} t ORDER BY user_id;
"""
+
+ File path = new File(jarPath)
+ if (!path.exists()) {
+ throw new IllegalStateException("""${jarPath} doesn't exist! """)
+ }
+
+ sql """ CREATE FUNCTION java_udf_string_test(string, int, int) RETURNS
string PROPERTIES (
+ "file"="file://${jarPath}",
+ "symbol"="org.apache.doris.udf.StringTest",
+ "type"="JAVA_UDF"
+ ); """
+
+ qt_select """ SELECT java_udf_string_test(string_col, 2, 3) result
FROM ${tableName} ORDER BY result; """
+ qt_select """ SELECT java_udf_string_test('abcdef', 2, 3),
java_udf_string_test('abcdefg', 2, 3) result FROM ${tableName} ORDER BY result;
"""
+
+ sql """ DROP FUNCTION java_udf_string_test(string, int, int); """
+ } finally {
+ try_sql("DROP TABLE IF EXISTS ${tableName}")
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]