This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new be909b6475 [GLUTEN-8836][CH] Support partition values with escape char (#8840)
be909b6475 is described below
commit be909b647502272f116cae0f07d811778b2a2539
Author: Wenzheng Liu <[email protected]>
AuthorDate: Wed Mar 5 14:03:21 2025 +0800
[GLUTEN-8836][CH] Support partition values with escape char (#8840)
---
.../execution/GlutenMergeTreePartition.scala | 22 ++-
.../delta/files/MergeTreeFileCommitProtocol.scala | 2 +-
.../v2/clickhouse/metadata/AddFileTags.scala | 3 +-
.../GlutenClickHouseNativeWriteTableSuite.scala | 160 ++++++++-------------
...GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala | 51 ++++++-
.../apache/spark/gluten/NativeWriteChecker.scala | 85 +++++++++++
.../Functions/SparkPartitionEscape.cpp | 109 ++++++++++++++
.../local-engine/Functions/SparkPartitionEscape.h | 57 ++++++++
.../CommonScalarFunctionParser.cpp | 1 +
.../Storages/MergeTree/SparkMergeTreeMeta.cpp | 4 +-
.../Storages/Output/NormalFileWriter.h | 14 +-
cpp-ch/local-engine/tests/CMakeLists.txt | 1 +
.../benchmark_spark_partition_escape_function.cpp | 53 +++++++
13 files changed, 453 insertions(+), 109 deletions(-)
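For context on the change (not part of the patch itself): vanilla Spark and Hive escape special characters in partition directory names by replacing each reserved character with '%' followed by its two-digit uppercase hex code, and this commit teaches the ClickHouse backend to do the same natively via a new sparkPartitionEscape function. A minimal Scala sketch of that escaping scheme, with illustrative names only, looks like this:

    // Hedged sketch of Hive/Spark-style partition-value escaping; PartitionEscapeSketch and
    // charsToEscape are hypothetical names, not APIs added by this commit.
    object PartitionEscapeSketch {
      // Control characters 0x01-0x1F plus a fixed set of reserved characters get escaped.
      private val charsToEscape: Set[Char] =
        (0x01 to 0x1f).map(_.toChar).toSet ++ "\"#%'*/:=?\\\u007F{[]^".toSet

      def escapePathName(value: String): String =
        value.map(c => if (charsToEscape(c)) f"%%${c.toInt}%02X" else c.toString).mkString

      def main(args: Array[String]): Unit = {
        println(escapePathName("2024=01/01")) // prints 2024%3D01%2F01
        println(escapePathName("a b"))        // prints a b (space is only escaped on Windows)
      }
    }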
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
index a4394740f8..cc7114dd72 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/GlutenMergeTreePartition.scala
@@ -19,6 +19,10 @@ package org.apache.gluten.execution
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.types.StructType
+import org.apache.hadoop.fs.Path
+
+import java.net.URI
+
case class MergeTreePartRange(
name: String,
dirName: String,
@@ -32,7 +36,7 @@ case class MergeTreePartRange(
}
}
-case class MergeTreePartSplit(
+case class MergeTreePartSplit private (
name: String,
dirName: String,
targetNode: String,
@@ -44,6 +48,22 @@ case class MergeTreePartSplit(
}
}
+object MergeTreePartSplit {
+ def apply(
+ name: String,
+ dirName: String,
+ targetNode: String,
+ start: Long,
+ length: Long,
+ bytesOnDisk: Long
+ ): MergeTreePartSplit = {
+ // Ref to org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath
+ val uriDecodeName = new Path(new URI(name)).toString
+ val uriDecodeDirName = new Path(new URI(dirName)).toString
+ new MergeTreePartSplit(uriDecodeName, uriDecodeDirName, targetNode, start, length, bytesOnDisk)
+ }
+}
+
case class GlutenMergeTreePartition(
index: Int,
engine: String,
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
index 13a9efa359..a8d572c93f 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/files/MergeTreeFileCommitProtocol.scala
@@ -52,7 +52,7 @@ trait MergeTreeFileCommitProtocol extends FileCommitProtocol {
dir: Option[String],
ext: String): String = {
- val partitionStr = dir.map(p => new Path(p).toUri.toString)
+ val partitionStr = dir.map(p => new Path(p).toString)
val bucketIdStr =
ext.split("\\.").headOption.filter(_.startsWith("_")).map(_.substring(1))
val split = taskContext.getTaskAttemptID.getTaskID.getId
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
index c4c971633a..df79b161cf 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v2/clickhouse/metadata/AddFileTags.scala
@@ -152,7 +152,8 @@ object AddFileTags {
rootNode.put("nullCount", "")
// Add the `stats` into delta meta log
val metricsStats = mapper.writeValueAsString(rootNode)
- AddFile(name, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
+ val uriName = new Path(name).toUri.toString
+ AddFile(uriName, partitionValues, bytesOnDisk, modificationTime, dataChange, metricsStats, tags)
}
def addFileToAddMergeTreeParts(addFile: AddFile): AddMergeTreeParts = {
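For readers tracing the path handling above: AddFileTags now stores the part name URI-encoded in the delta log, and MergeTreePartSplit (first hunk of this commit) decodes it again before the path reaches the native reader, mirroring org.apache.spark.sql.delta.files.TahoeFileIndex.absolutePath. A standalone Scala sketch of that encode/decode round trip, using a hypothetical path value and helper names:

    // Standalone sketch of the URI encode/decode round trip; PartNameRoundTrip, encode and
    // decode are illustrative helpers, not code from the patch.
    import java.net.URI
    import org.apache.hadoop.fs.Path

    object PartNameRoundTrip {
      // Encoding as done when the name is written into the delta log.
      def encode(rawPath: String): String = new Path(rawPath).toUri.toString

      // Decoding as done when MergeTreePartSplit is built for the native reader.
      def decode(encoded: String): String = new Path(new URI(encoded)).toString

      def main(args: Array[String]): Unit = {
        val raw = "/data/tbl/part_col2=2024%3D01%2F01/part-00000" // hypothetical escaped dir
        val encoded = encode(raw) // the '%' signs are re-encoded as %25 by the URI constructor
        println(encoded)
        println(decode(encoded))  // prints the original string again
      }
    }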
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
index 1ee0b18b11..7e8ca2236c 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala
@@ -23,12 +23,14 @@ import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData
import org.apache.spark.SparkConf
import org.apache.spark.gluten.NativeWriteChecker
+import org.apache.spark.sql.Row
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
import org.apache.spark.sql.types._
-import scala.reflect.runtime.universe.TypeTag
+import java.io.File
+import java.sql.Date
class GlutenClickHouseNativeWriteTableSuite
extends GlutenClickHouseWholeStageTransformerSuite
@@ -67,12 +69,6 @@ class GlutenClickHouseNativeWriteTableSuite
.setMaster("local[1]")
}
- private def getWarehouseDir = {
- // test non-ascii path, by the way
- // scalastyle:off nonascii
- basePath + "/中文/spark-warehouse"
- }
-
private val table_name_template = "hive_%s_test"
private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla"
@@ -81,58 +77,7 @@ class GlutenClickHouseNativeWriteTableSuite
super.afterAll()
}
- def getColumnName(s: String): String = {
- s.replaceAll("\\(", "_").replaceAll("\\)", "_")
- }
-
import collection.immutable.ListMap
-
- import java.io.File
-
- def compareSource(original_table: String, table_name: String, fields: Seq[String]): Unit = {
- val rowsFromOriginTable =
- spark.sql(s"select ${fields.mkString(",")} from
$original_table").collect()
- val dfFromWriteTable =
- spark.sql(
- s"select " +
- s"${fields
- .map(getColumnName)
- .mkString(",")} " +
- s"from $table_name")
- checkAnswer(dfFromWriteTable, rowsFromOriginTable)
- }
- def writeAndCheckRead(
- original_table: String,
- table_name: String,
- fields: Seq[String],
- checkNative: Boolean = true)(write: Seq[String] => Unit): Unit =
- withDestinationTable(table_name) {
- withNativeWriteCheck(checkNative) {
- write(fields)
- }
- compareSource(original_table, table_name, fields)
- }
-
- def recursiveListFiles(f: File): Array[File] = {
- val these = f.listFiles
- these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
- }
-
- def getSignature(format: String, filesOfNativeWriter: Array[File]): Array[(Long, Long)] = {
- filesOfNativeWriter.map(
- f => {
- val df = if (format.equals("parquet")) {
- spark.read.parquet(f.getAbsolutePath)
- } else {
- spark.read.orc(f.getAbsolutePath)
- }
- (
- df.count(),
- df.agg(("int_field",
"sum")).collect().apply(0).apply(0).asInstanceOf[Long]
- )
- })
- }
-
private val fields_ = ListMap(
("string_field", "string"),
("int_field", "int"),
@@ -146,22 +91,6 @@ class GlutenClickHouseNativeWriteTableSuite
("date_field", "date")
)
- def nativeWrite2(
- f: String => (String, String, String),
- extraCheck: (String, String) => Unit = null,
- checkNative: Boolean = true): Unit = nativeWrite {
- format =>
- val (table_name, table_create_sql, insert_sql) = f(format)
- withDestinationTable(table_name, Option(table_create_sql)) {
- checkInsertQuery(insert_sql, checkNative)
- Option(extraCheck).foreach(_(table_name, format))
- }
- }
-
- def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)(
- block: => Unit): Unit =
- withSource(spark.createDataFrame(data), viewName, pairs: _*)(block)
-
private lazy val supplierSchema = StructType.apply(
Seq(
StructField.apply("s_suppkey", LongType, nullable = true),
@@ -618,18 +547,7 @@ class GlutenClickHouseNativeWriteTableSuite
.saveAsTable(table_name_vanilla)
}
}
- val sigsOfNativeWriter =
- getSignature(
- format,
- recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
- .filter(_.getName.endsWith(s".$format"))).sorted
- val sigsOfVanillaWriter =
- getSignature(
- format,
- recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
- .filter(_.getName.endsWith(s".$format"))).sorted
-
- assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+ compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
}
}
}
@@ -680,18 +598,7 @@ class GlutenClickHouseNativeWriteTableSuite
.bucketBy(10, "byte_field", "string_field")
.saveAsTable(table_name_vanilla)
}
- val sigsOfNativeWriter =
- getSignature(
- format,
- recursiveListFiles(new File(getWarehouseDir + "/" + table_name))
- .filter(_.getName.endsWith(s".$format"))).sorted
- val sigsOfVanillaWriter =
- getSignature(
- format,
- recursiveListFiles(new File(getWarehouseDir + "/" + table_name_vanilla))
- .filter(_.getName.endsWith(s".$format"))).sorted
-
- assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+ compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(int_field)")
}
}
}
@@ -754,6 +661,63 @@ class GlutenClickHouseNativeWriteTableSuite
}
}
+ test("test partitioned with escaped characters") {
+
+ val schema = StructType(
+ Seq(
+ StructField.apply("id", IntegerType, nullable = true),
+ StructField.apply("escape", StringType, nullable = true),
+ StructField.apply("bucket/col", StringType, nullable = true),
+ StructField.apply("part=col1", DateType, nullable = true),
+ StructField.apply("part_col2", StringType, nullable = true)
+ ))
+
+ val data: Seq[Row] = Seq(
+ Row(1, "=", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
+ Row(2, "/", "00000", Date.valueOf("2024-01-01"), "2024=01/01"),
+ Row(3, "#", "00000", Date.valueOf("2024-01-01"), "2024#01:01"),
+ Row(4, ":", "00001", Date.valueOf("2024-01-02"), "2024#01:01"),
+ Row(5, "\\", "00001", Date.valueOf("2024-01-02"), "2024\\01\u000101"),
+ Row(6, "\u0001", "000001", Date.valueOf("2024-01-02"),
"2024\\01\u000101"),
+ Row(7, "", "000002", null, null)
+ )
+
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ df.createOrReplaceTempView("origin_table")
+ spark.sql("select * from origin_table").show()
+
+ nativeWrite {
+ format =>
+ val table_name = table_name_template.format(format)
+ spark.sql(s"drop table IF EXISTS $table_name")
+ writeAndCheckRead("origin_table", table_name, schema.fieldNames.map(f
=> s"`$f`")) {
+ _ =>
+ spark
+ .table("origin_table")
+ .write
+ .format(format)
+ .partitionBy("part=col1", "part_col2")
+ .bucketBy(2, "bucket/col")
+ .saveAsTable(table_name)
+ }
+
+ val table_name_vanilla = table_name_vanilla_template.format(format)
+ spark.sql(s"drop table IF EXISTS $table_name_vanilla")
+ withSQLConf((GlutenConfig.NATIVE_WRITER_ENABLED.key, "false")) {
+ withNativeWriteCheck(checkNative = false) {
+ spark
+ .table("origin_table")
+ .write
+ .format(format)
+ .partitionBy("part=col1", "part_col2")
+ .bucketBy(2, "bucket/col")
+ .saveAsTable(table_name_vanilla)
+ }
+ compareWriteFilesSignature(format, table_name, table_name_vanilla, "sum(id)")
+ }
+ }
+ }
+
test("test bucketed by constant") {
nativeWrite {
format =>
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
index 6d404fe3aa..faffe19136 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/mergetree/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala
@@ -21,12 +21,13 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.{FileSourceScanExecTransformer, GlutenClickHouseTPCHAbstractSuite}
import org.apache.spark.SparkConf
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.delta.catalog.ClickHouseTableV2
import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.mergetree.StorageMeta
import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
+import org.apache.spark.sql.types._
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
@@ -359,6 +360,54 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite
spark.sql("drop table lineitem_mergetree_partition_hdfs")
}
+ test("test partition values with escape chars") {
+
+ val schema = StructType(
+ Seq(
+ StructField.apply("id", IntegerType, nullable = true),
+ StructField.apply("escape", StringType, nullable = true)
+ ))
+
+ // scalastyle:off nonascii
+ val data: Seq[Row] = Seq(
+ Row(1, "="),
+ Row(2, "/"),
+ Row(3, "#"),
+ Row(4, ":"),
+ Row(5, "\\"),
+ Row(6, "\u0001"),
+ Row(7, "中文"),
+ Row(8, " "),
+ Row(9, "a b")
+ )
+ // scalastyle:on nonascii
+
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ df.createOrReplaceTempView("origin_table")
+
+ // spark.conf.set("spark.gluten.enabled", "false")
+ spark.sql(s"""
+ |DROP TABLE IF EXISTS partition_escape;
+ |""".stripMargin)
+
+ spark.sql(s"""
+ |CREATE TABLE IF NOT EXISTS partition_escape
+ |(
+ | c1 int,
+ | c2 string
+ |)
+ |USING clickhouse
+ |PARTITIONED BY (c2)
+ |TBLPROPERTIES (storage_policy='__hdfs_main',
+ | orderByKey='c1',
+ | primaryKey='c1')
+ |LOCATION '$HDFS_URL/test/partition_escape'
+ |""".stripMargin)
+
+ spark.sql("insert into partition_escape select * from origin_table")
+ spark.sql("select * from partition_escape").show()
+ }
+
testSparkVersionLE33("test mergetree write with bucket table") {
spark.sql(s"""
|DROP TABLE IF EXISTS lineitem_mergetree_bucket_hdfs;
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
index 384780e7d2..481e340d87 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala
@@ -25,12 +25,21 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FakeRowAdaptor
import org.apache.spark.sql.util.QueryExecutionListener
+import java.io.File
+
+import scala.reflect.runtime.universe.TypeTag
+
trait NativeWriteChecker
extends GlutenClickHouseWholeStageTransformerSuite
with AdaptiveSparkPlanHelper {
private val formats: Seq[String] = Seq("orc", "parquet")
+ // test non-ascii path, by the way
+ // scalastyle:off nonascii
+ protected def getWarehouseDir: String = basePath + "/中文/spark-warehouse"
+ // scalastyle:on nonascii
+
def withNativeWriteCheck(checkNative: Boolean)(block: => Unit): Unit = {
var nativeUsed = false
@@ -82,6 +91,22 @@ trait NativeWriteChecker
}
}
+ def nativeWrite2(
+ f: String => (String, String, String),
+ extraCheck: (String, String) => Unit = null,
+ checkNative: Boolean = true): Unit = nativeWrite {
+ format =>
+ val (table_name, table_create_sql, insert_sql) = f(format)
+ withDestinationTable(table_name, Option(table_create_sql)) {
+ checkInsertQuery(insert_sql, checkNative)
+ Option(extraCheck).foreach(_(table_name, format))
+ }
+ }
+
+ def withSource[A <: Product: TypeTag](data: Seq[A], viewName: String, pairs: (String, String)*)(
+ block: => Unit): Unit =
+ withSource(spark.createDataFrame(data), viewName, pairs: _*)(block)
+
def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)(
block: => Unit): Unit = {
withSQLConf(pairs: _*) {
@@ -91,4 +116,64 @@ trait NativeWriteChecker
}
}
}
+
+ def getColumnName(col: String): String = {
+ col.replaceAll("\\(", "_").replaceAll("\\)", "_")
+ }
+
+ def compareSource(originTable: String, table: String, fields: Seq[String]): Unit = {
+ def query(table: String, selectFields: Seq[String]): String = {
+ s"select ${selectFields.mkString(",")} from $table"
+ }
+ val expectedRows = spark.sql(query(originTable, fields)).collect()
+ val actual = spark.sql(query(table, fields.map(getColumnName)))
+ checkAnswer(actual, expectedRows)
+ }
+
+ def writeAndCheckRead(
+ original_table: String,
+ table_name: String,
+ fields: Seq[String],
+ checkNative: Boolean = true)(write: Seq[String] => Unit): Unit = {
+ withDestinationTable(table_name) {
+ withNativeWriteCheck(checkNative) {
+ write(fields)
+ }
+ compareSource(original_table, table_name, fields)
+ }
+ }
+
+ def compareWriteFilesSignature(
+ format: String,
+ table: String,
+ vanillaTable: String,
+ sigExpr: String): Unit = {
+ val tableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + table))
+ .filter(_.getName.endsWith(s".$format"))
+ val sigsOfNativeWriter = getSignature(format, tableFiles, sigExpr).sorted
+ val vanillaTableFiles = recursiveListFiles(new File(getWarehouseDir + "/" + vanillaTable))
+ .filter(_.getName.endsWith(s".$format"))
+ val sigsOfVanillaWriter = getSignature(format, vanillaTableFiles, sigExpr).sorted
+ assertResult(sigsOfVanillaWriter)(sigsOfNativeWriter)
+ }
+
+ def recursiveListFiles(f: File): Array[File] = {
+ val these = f.listFiles
+ these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
+ }
+
+ def getSignature(
+ format: String,
+ writeFiles: Array[File],
+ sigExpr: String): Array[(Long, Long)] = {
+ writeFiles.map(
+ f => {
+ val df = if (format.equals("parquet")) {
+ spark.read.parquet(f.getAbsolutePath)
+ } else {
+ spark.read.orc(f.getAbsolutePath)
+ }
+ (df.count(), df.selectExpr(sigExpr).collect().apply(0).apply(0).asInstanceOf[Long])
+ })
+ }
}
diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp
new file mode 100644
index 0000000000..522f9ddee2
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.cpp
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "SparkPartitionEscape.h"
+#include <Functions/FunctionFactory.h>
+#include <Common/Exception.h>
+#include <DataTypes/IDataType.h>
+#include <DataTypes/DataTypeString.h>
+#include <sstream>
+#include <iomanip>
+#include <string>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+const std::vector<char> SparkPartitionEscape::ESCAPE_CHAR_LIST = {
+ '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009',
+ '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013',
+ '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C',
+ '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F',
+ '{', '[', ']', '^'
+};
+
+const std::bitset<128> SparkPartitionEscape::ESCAPE_BITSET = []()
+{
+ std::bitset<128> bitset;
+ for (char c : SparkPartitionEscape::ESCAPE_CHAR_LIST)
+ {
+ bitset.set(c);
+ }
+#ifdef _WIN32
+ bitset.set(' ');
+ bitset.set('<');
+ bitset.set('>');
+ bitset.set('|');
+#endif
+ return bitset;
+}();
+
+static bool needsEscaping(char c) {
+ return c >= 0 && c < SparkPartitionEscape::ESCAPE_BITSET.size()
+ && SparkPartitionEscape::ESCAPE_BITSET.test(c);
+}
+
+static std::string escapePathName(const std::string & path) {
+ std::ostringstream builder;
+ for (char c : path) {
+ if (needsEscaping(c)) {
builder << '%' << std::uppercase << std::setw(2) << std::setfill('0') << std::hex << (int)c;
+ } else {
+ builder << c;
+ }
+ }
+
+ return builder.str();
+}
+
+DB::DataTypePtr SparkPartitionEscape::getReturnTypeImpl(const DB::DataTypes & arguments) const
+{
+ if (arguments.size() != 1)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} argument size must be 1", name);
+
+ if (!isString(arguments[0]))
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be String", getName());
+
+ return std::make_shared<DataTypeString>();
+}
+
+DB::ColumnPtr SparkPartitionEscape::executeImpl(
+ const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const
+{
+ auto result = result_type->createColumn();
+ result->reserve(input_rows_count);
+
+ for (size_t i = 0; i < input_rows_count; ++i)
+ {
+ auto escaped_name = escapePathName(arguments[0].column->getDataAt(i).toString());
+ result->insertData(escaped_name.c_str(), escaped_name.size());
+ }
+ return result;
+}
+
+REGISTER_FUNCTION(SparkPartitionEscape)
+{
+ factory.registerFunction<SparkPartitionEscape>();
+}
+}
diff --git a/cpp-ch/local-engine/Functions/SparkPartitionEscape.h b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h
new file mode 100644
index 0000000000..916134506a
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkPartitionEscape.h
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Columns/IColumn.h>
+#include <Core/ColumnsWithTypeAndName.h>
+#include <DataTypes/DataTypeDate.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/IDataType.h>
+#include <Functions/IFunction.h>
+#include <Interpreters/Context.h>
+#include <bitset>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+using namespace DB;
+
+namespace local_engine
+{
+
+class SparkPartitionEscape : public DB::IFunction
+{
+public:
+ static const std::vector<char> ESCAPE_CHAR_LIST;
+ static const std::bitset<128> ESCAPE_BITSET;
+ static constexpr auto name = "sparkPartitionEscape";
+ static FunctionPtr create(ContextPtr /*context*/) { return std::make_shared<SparkPartitionEscape>(); }
+ SparkPartitionEscape() = default;
+ ~SparkPartitionEscape() override = default;
+ String getName() const override { return name; }
+ size_t getNumberOfArguments() const override { return 1; }
+ bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+ DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & /*arguments*/) const override;
+ DB::ColumnPtr executeImpl(
+ const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t /*input_rows_count*/) const override;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index ba2129deb1..841e51c00a 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -130,6 +130,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Uuid, uuid, generateUUIDv4);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Levenshtein, levenshtein, editDistanceUTF8);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(FormatString, format_string, printf);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(SoundEx, soundex, soundex);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(PartitionEscape, partition_escape, sparkPartitionEscape);
// hash functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Crc32, crc32, CRC32);
diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
index ef4d5504ff..b8e1a3cbd4 100644
--- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
+++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeMeta.cpp
@@ -212,9 +212,7 @@ MergeTreeTableInstance::MergeTreeTableInstance(const std::string & info) : Merge
while (!in.eof())
{
MergeTreePart part;
- std::string encoded_name;
- readString(encoded_name, in);
- Poco::URI::decode(encoded_name, part.name);
+ readString(part.name, in);
assertChar('\n', in);
readIntText(part.begin, in);
assertChar('\n', in);
diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
index c0c762906a..77096f3f49 100644
--- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
+++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h
@@ -416,17 +416,23 @@ public:
for (const auto & column : partition_columns)
{
// partition_column=
- std::string key = add_slash ? fmt::format("/{}=", column) : fmt::format("{}=", column);
+ auto column_name = std::make_shared<DB::ASTLiteral>(column);
+ auto escaped_name = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_name});
+ if (add_slash)
+ arguments.emplace_back(std::make_shared<DB::ASTLiteral>("/"));
add_slash = true;
- arguments.emplace_back(std::make_shared<DB::ASTLiteral>(key));
+ arguments.emplace_back(escaped_name);
+ arguments.emplace_back(std::make_shared<DB::ASTLiteral>("="));
// ifNull(toString(partition_column), DEFAULT_PARTITION_NAME)
// FIXME if toString(partition_column) is empty
- auto column_ast = std::make_shared<DB::ASTIdentifier>(column);
+ auto column_ast = makeASTFunction("toString", DB::ASTs{std::make_shared<DB::ASTIdentifier>(column)});
+ auto escaped_value = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_ast});
DB::ASTs if_null_args{
- makeASTFunction("toString", DB::ASTs{column_ast}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
+ makeASTFunction("toString", DB::ASTs{escaped_value}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args)));
}
+
if (isBucketedWrite(input_header))
{
DB::ASTs args{std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)};
diff --git a/cpp-ch/local-engine/tests/CMakeLists.txt b/cpp-ch/local-engine/tests/CMakeLists.txt
index 09ca32a01a..9c18d70b0f 100644
--- a/cpp-ch/local-engine/tests/CMakeLists.txt
+++ b/cpp-ch/local-engine/tests/CMakeLists.txt
@@ -107,6 +107,7 @@ if(ENABLE_BENCHMARKS)
benchmark_spark_row.cpp
benchmark_unix_timestamp_function.cpp
benchmark_spark_functions.cpp
+ benchmark_spark_partition_escape_function.cpp
benchmark_cast_float_function.cpp
benchmark_to_datetime_function.cpp
benchmark_spark_divide_function.cpp
diff --git a/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp
new file mode 100644
index 0000000000..3299eb9fee
--- /dev/null
+++ b/cpp-ch/local-engine/tests/benchmark_spark_partition_escape_function.cpp
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Core/Block.h>
+#include <DataTypes/DataTypeFactory.h>
+#include <Functions/FunctionFactory.h>
+#include <Parser/FunctionParser.h>
+#include <benchmark/benchmark.h>
+#include <Common/QueryContext.h>
+
+using namespace DB;
+
+static Block createDataBlock(size_t rows)
+{
+ auto type = DataTypeFactory::instance().get("String");
+ auto column = type->createColumn();
+ for (size_t i = 0; i < rows; ++i)
+ {
+ char ch = static_cast<char>(i % 128);
+ std::string str = "escape_" + std::string(1, ch);
+ column->insert(str);
+ }
+ Block block;
+ block.insert(ColumnWithTypeAndName(std::move(column), type, "d"));
+ return std::move(block);
+}
+
+static void BM_CHSparkPartitionEscape(benchmark::State & state)
+{
+ using namespace DB;
+ auto & factory = FunctionFactory::instance();
+ auto function = factory.get("sparkPartitionEscape", local_engine::QueryContext::globalContext());
+ Block block = createDataBlock(1000000);
+ auto executable = function->build(block.getColumnsWithTypeAndName());
+ for (auto _ : state) [[maybe_unused]]
+ auto result = executable->execute(block.getColumnsWithTypeAndName(), executable->getResultType(), block.rows(), false);
+}
+
+BENCHMARK(BM_CHSparkPartitionEscape)->Unit(benchmark::kMillisecond)->Iterations(50);
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]