This is an automated email from the ASF dual-hosted git repository.

mahongbin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 7ff18ee85 [Gluten-4912][CH]Support Specifying columns in clickhouse tables to be Low Cardinality (#4925)
7ff18ee85 is described below

commit 7ff18ee85d2c66d3ca8c3936bea457070fb450ac
Author: Hongbin Ma <mahong...@apache.org>
AuthorDate: Tue Mar 12 15:30:44 2024 +0800

    [Gluten-4912][CH]Support Specifying columns in clickhouse tables to be Low Cardinality (#4925)
---
 .../source/DeltaMergeTreeFileFormat.scala          |   4 +
 .../source/DeltaMergeTreeFileFormat.scala          |   4 +
 .../backendsapi/clickhouse/CHIteratorApi.scala     |   1 +
 .../execution/GlutenMergeTreePartition.scala       |   1 +
 .../delta/ClickhouseOptimisticTransaction.scala    |   2 +
 .../sql/delta/catalog/ClickHouseTableV2.scala      |  22 +++
 .../utils/MergeTreePartsPartitionsUtil.scala       |  11 ++
 .../datasources/v1/CHMergeTreeWriterInjects.scala  |   9 +
 .../v1/clickhouse/MergeTreeFileFormatWriter.scala  |   3 +
 .../GlutenClickHouseMergeTreeWriteSuite.scala      |  90 ++++++++++
 .../apache/spark/affinity/MixedAffinitySuite.scala |   1 +
 .../local-engine/Builder/SerializedPlanBuilder.cpp |   8 +-
 cpp-ch/local-engine/Common/MergeTreeTool.cpp       |   2 +
 cpp-ch/local-engine/Common/MergeTreeTool.h         |   1 +
 cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp  |   6 +-
 .../local-engine/Parser/SerializedPlanParser.cpp   | 190 +++++++++++++++------
 cpp-ch/local-engine/Parser/TypeParser.cpp          |  24 ++-
 cpp-ch/local-engine/Parser/TypeParser.h            |   3 +-
 cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp    |   4 +-
 cpp-ch/local-engine/Storages/IO/NativeWriter.cpp   |   5 +-
 .../substrait/rel/ExtensionTableBuilder.java       |   2 +
 .../substrait/rel/ExtensionTableNode.java          |   5 +
 .../datasource/GlutenFormatWriterInjects.scala     |   1 +
 23 files changed, 328 insertions(+), 71 deletions(-)

diff --git a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
index 09f17c468..fef109d35 100644
--- a/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
+++ b/backends-clickhouse/src/main/delta-20/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala
@@ -33,6 +33,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
   protected var tableName = ""
   protected var dataSchemas = Seq.empty[Attribute]
   protected var orderByKeyOption: Option[Seq[String]] = None
+  protected var lowCardKeyOption: Option[Seq[String]] = None
   protected var primaryKeyOption: Option[Seq[String]] = None
   protected var partitionColumns: Seq[String] = Seq.empty[String]
   protected var clickhouseTableConfigs: Map[String, String] = Map.empty
@@ -43,6 +44,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
       tableName: String,
       schemas: Seq[Attribute],
       orderByKeyOption: Option[Seq[String]],
+      lowCardKeyOption: Option[Seq[String]],
       primaryKeyOption: Option[Seq[String]],
       clickhouseTableConfigs: Map[String, String],
       partitionColumns: Seq[String]) {
@@ -51,6 +53,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
     this.tableName = tableName
     this.dataSchemas = schemas
     this.orderByKeyOption = orderByKeyOption
+    this.lowCardKeyOption = lowCardKeyOption
    this.primaryKeyOption = primaryKeyOption
     this.clickhouseTableConfigs = clickhouseTableConfigs
     this.partitionColumns = partitionColumns
@@ -98,6 +101,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata)
       database,
       tableName,
       orderByKeyOption,
+
lowCardKeyOption, primaryKeyOption, partitionColumns, metadata.schema, diff --git a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala index 1dd341d2e..b87420787 100644 --- a/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala +++ b/backends-clickhouse/src/main/delta-22/org/apache/spark/sql/execution/datasources/v2/clickhouse/source/DeltaMergeTreeFileFormat.scala @@ -32,6 +32,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma protected var tableName = "" protected var dataSchemas = Seq.empty[Attribute] protected var orderByKeyOption: Option[Seq[String]] = None + protected var lowCardKeyOption: Option[Seq[String]] = None protected var primaryKeyOption: Option[Seq[String]] = None protected var partitionColumns: Seq[String] = Seq.empty[String] protected var clickhouseTableConfigs: Map[String, String] = Map.empty @@ -42,6 +43,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma tableName: String, schemas: Seq[Attribute], orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], clickhouseTableConfigs: Map[String, String], partitionColumns: Seq[String]) { @@ -50,6 +52,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma this.tableName = tableName this.dataSchemas = schemas this.orderByKeyOption = orderByKeyOption + this.lowCardKeyOption = lowCardKeyOption this.primaryKeyOption = primaryKeyOption this.clickhouseTableConfigs = clickhouseTableConfigs this.partitionColumns = partitionColumns @@ -97,6 +100,7 @@ class DeltaMergeTreeFileFormat(metadata: Metadata) extends DeltaParquetFileForma database, tableName, orderByKeyOption, + lowCardKeyOption, primaryKeyOption, partitionColumns, metadata.schema, diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index b7ba20b64..1841faccf 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -97,6 +97,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { p.relativeTablePath, p.absoluteTablePath, p.orderByKey, + p.lowCardKey, p.primaryKey, partLists, starts, diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala index 16c71cb09..df41191f2 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/GlutenMergeTreePartition.scala @@ -38,6 +38,7 @@ case class GlutenMergeTreePartition( relativeTablePath: String, absoluteTablePath: String, orderByKey: String, + lowCardKey: String, primaryKey: String, partList: Array[MergeTreePartSplit], tableSchemaJson: String, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala 
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala index b8af5e414..e4786168e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/ClickhouseOptimisticTransaction.scala @@ -130,6 +130,7 @@ class ClickhouseOptimisticTransaction( tableV2.tableName, output, tableV2.orderByKeyOption, + tableV2.lowCardKeyOption, tableV2.primaryKeyOption, tableV2.clickhouseTableConfigs, tableV2.partitionColumns @@ -142,6 +143,7 @@ class ClickhouseOptimisticTransaction( spark.sessionState.newHadoopConfWithOptions(metadata.configuration ++ deltaLog.options), // scalastyle:on deltahadoopconfiguration orderByKeyOption = tableV2.orderByKeyOption, + lowCardKeyOption = tableV2.lowCardKeyOption, primaryKeyOption = tableV2.primaryKeyOption, partitionColumns = partitioningColumns, bucketSpec = tableV2.bucketOption, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala index 148939ede..e06a01edf 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/delta/catalog/ClickHouseTableV2.scala @@ -113,6 +113,27 @@ class ClickHouseTableV2( } } + lazy val lowCardKeyOption: Option[Seq[String]] = { + val tableProperties = properties() + if (tableProperties.containsKey("lowCardKey")) { + if (tableProperties.get("lowCardKey").nonEmpty) { + val lowCardKeys = tableProperties.get("lowCardKey").split(",").map(_.trim).toSeq + lowCardKeys.foreach( + s => { + if (s.contains(".")) { + throw new IllegalStateException( + s"lowCardKey $s can not contain '.' 
(not support nested column yet)") + } + }) + Some(lowCardKeys.map(s => s.toLowerCase())) + } else { + None + } + } else { + None + } + } + lazy val orderByKeyOption: Option[Seq[String]] = { if (bucketOption.isDefined && bucketOption.get.sortColumnNames.nonEmpty) { val orderByKes = bucketOption.get.sortColumnNames @@ -240,6 +261,7 @@ class ClickHouseTableV2( tableName, Seq.empty[Attribute], orderByKeyOption, + lowCardKeyOption, primaryKeyOption, clickhouseTableConfigs, partitionColumns) diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala index 85d2b1176..a7ac2ce16 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/utils/MergeTreePartsPartitionsUtil.scala @@ -69,6 +69,11 @@ object MergeTreePartsPartitionsUtil extends Logging { val (orderByKey, primaryKey) = MergeTreeDeltaUtil.genOrderByAndPrimaryKeyStr(table.orderByKeyOption, table.primaryKeyOption) + val lowCardKey = table.lowCardKeyOption match { + case Some(keys) => keys.mkString(",") + case None => "" + } + val tableSchemaJson = ConverterUtils.convertNamedStructJson(table.schema()) // bucket table @@ -86,6 +91,7 @@ object MergeTreePartsPartitionsUtil extends Logging { tableSchemaJson, partitions, orderByKey, + lowCardKey, primaryKey, table.clickhouseTableConfigs, sparkSession @@ -102,6 +108,7 @@ object MergeTreePartsPartitionsUtil extends Logging { tableSchemaJson, partitions, orderByKey, + lowCardKey, primaryKey, table.clickhouseTableConfigs, sparkSession @@ -121,6 +128,7 @@ object MergeTreePartsPartitionsUtil extends Logging { tableSchemaJson: String, partitions: ArrayBuffer[InputPartition], orderByKey: String, + lowCardKey: String, primaryKey: String, clickhouseTableConfigs: Map[String, String], sparkSession: SparkSession): Unit = { @@ -205,6 +213,7 @@ object MergeTreePartsPartitionsUtil extends Logging { relativeTablePath, absoluteTablePath, orderByKey, + lowCardKey, primaryKey, currentFiles.toArray, tableSchemaJson, @@ -246,6 +255,7 @@ object MergeTreePartsPartitionsUtil extends Logging { tableSchemaJson: String, partitions: ArrayBuffer[InputPartition], orderByKey: String, + lowCardKey: String, primaryKey: String, clickhouseTableConfigs: Map[String, String], sparkSession: SparkSession): Unit = { @@ -307,6 +317,7 @@ object MergeTreePartsPartitionsUtil extends Logging { relativeTablePath, absoluteTablePath, orderByKey, + lowCardKey, primaryKey, currentFiles.toArray, tableSchemaJson, diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala index ffc2ea1b8..d05945a81 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/CHMergeTreeWriterInjects.scala @@ -67,6 +67,7 @@ class CHMergeTreeWriterInjects extends GlutenFormatWriterInjectsBase { database: String, tableName: String, orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], partitionColumns: Seq[String], tableSchema: 
StructType, @@ -81,6 +82,7 @@ class CHMergeTreeWriterInjects extends GlutenFormatWriterInjectsBase { database, tableName, orderByKeyOption, + lowCardKeyOption, primaryKeyOption, partitionColumns, ConverterUtils.convertNamedStructJson(tableSchema), @@ -122,6 +124,7 @@ object CHMergeTreeWriterInjects { database: String, tableName: String, orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], partitionColumns: Seq[String], tableSchemaJson: String, @@ -143,6 +146,11 @@ object CHMergeTreeWriterInjects { primaryKeyOption ) + val lowCardKey = lowCardKeyOption match { + case Some(keys) => keys.mkString(",") + case None => "" + } + val substraitContext = new SubstraitContext val extensionTableNode = ExtensionTableBuilder.makeExtensionTable( -1, @@ -152,6 +160,7 @@ object CHMergeTreeWriterInjects { path, "", orderByKey, + lowCardKey, primaryKey, new JList[String](), new JList[JLong](), diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala index a827313e6..874a4ede3 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/datasources/v1/clickhouse/MergeTreeFileFormatWriter.scala @@ -56,6 +56,7 @@ object MergeTreeFileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -69,6 +70,7 @@ object MergeTreeFileFormatWriter extends Logging { outputSpec = outputSpec, hadoopConf = hadoopConf, orderByKeyOption = orderByKeyOption, + lowCardKeyOption = lowCardKeyOption, primaryKeyOption = primaryKeyOption, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -86,6 +88,7 @@ object MergeTreeFileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala index fcf1b3e76..2960b502b 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseMergeTreeWriteSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMerg import java.io.File +import scala.io.Source + // Some sqls' line length exceeds 100 // scalastyle:off line.size.limit @@ -1200,5 +1202,93 @@ class GlutenClickHouseMergeTreeWriteSuite } + test("test mergetree table with low cardinality column") { + spark.sql(s""" + |DROP TABLE IF EXISTS lineitem_mergetree_lowcard; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS lineitem_mergetree_lowcard + |( + | l_orderkey bigint, + | l_partkey bigint, + | l_suppkey bigint, + | l_linenumber bigint, + | l_quantity double, + | l_extendedprice double, + | 
l_discount double, + | l_tax double, + | l_returnflag string, + | l_linestatus string, + | l_shipdate date, + | l_commitdate date, + | l_receiptdate date, + | l_shipinstruct string, + | l_shipmode string, + | l_comment string + |) + |USING clickhouse + |LOCATION '$basePath/lineitem_mergetree_lowcard' + |TBLPROPERTIES('lowCardKey'='l_returnflag,L_LINESTATUS') + |""".stripMargin) + + spark.sql(s""" + | insert into table lineitem_mergetree_lowcard + | select * from lineitem + |""".stripMargin) + + val sqlStr = + s""" + |SELECT + | l_returnflag, + | l_linestatus, + | sum(l_quantity) AS sum_qty, + | sum(l_extendedprice) AS sum_base_price, + | sum(l_extendedprice * (1 - l_discount)) AS sum_disc_price, + | sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) AS sum_charge, + | avg(l_quantity) AS avg_qty, + | avg(l_extendedprice) AS avg_price, + | avg(l_discount) AS avg_disc, + | count(*) AS count_order + |FROM + | lineitem_mergetree_lowcard + |WHERE + | l_shipdate <= date'1998-09-02' - interval 1 day + |GROUP BY + | l_returnflag, + | l_linestatus + |ORDER BY + | l_returnflag, + | l_linestatus; + | + |""".stripMargin + runTPCHQueryBySQL(1, sqlStr) { _ => {} } + val directory = new File(s"$basePath/lineitem_mergetree_lowcard") + // find a folder whose name is like 48b70783-b3b8-4bf8-9c52-5261aead8e3e_0_006 + val partDir = directory.listFiles().filter(f => f.getName.length > 20).head + val columnsFile = new File(partDir, "columns.txt") + val columns = Source.fromFile(columnsFile).getLines().mkString + assert(columns.contains("`l_returnflag` LowCardinality(Nullable(String))")) + assert(columns.contains("`l_linestatus` LowCardinality(Nullable(String))")) + + // test low card column in measure + val sqlStr2 = + s""" + |SELECT + | max(l_returnflag) + |FROM + | lineitem_mergetree_lowcard + |GROUP BY + | l_linestatus + | order by l_linestatus + | + |""".stripMargin + + assert( + // total rows should remain unchanged + spark.sql(sqlStr2).collect().apply(0).get(0) == "R" + ) + } + } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala b/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala index c578367d3..e6910a430 100644 --- a/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/spark/affinity/MixedAffinitySuite.scala @@ -58,6 +58,7 @@ class MixedAffinitySuite extends QueryTest with SharedSparkSession { "fakePath2", "", "", + "", Array(file), "", Map.empty) diff --git a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp index 214a52065..be5a9ca4b 100644 --- a/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp +++ b/cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp @@ -17,6 +17,7 @@ #include "SerializedPlanBuilder.h" #include <DataTypes/DataTypeArray.h> #include <DataTypes/DataTypeDateTime64.h> +#include <DataTypes/DataTypeLowCardinality.h> #include <DataTypes/DataTypeMap.h> #include <DataTypes/DataTypeNullable.h> #include <DataTypes/DataTypeTuple.h> @@ -220,12 +221,15 @@ SerializedPlanBuilder & SerializedPlanBuilder::project(const std::vector<substra std::shared_ptr<substrait::Type> SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type) { - const auto * ch_type_nullable = checkAndGetDataType<DataTypeNullable>(ch_type.get()); + const auto ch_type_wo_lowcardinality = DB::removeLowCardinality(ch_type); + + const auto * ch_type_nullable = 
checkAndGetDataType<DataTypeNullable>(ch_type_wo_lowcardinality.get()); + const bool is_nullable = (ch_type_nullable != nullptr); auto type_nullability = is_nullable ? substrait::Type_Nullability_NULLABILITY_NULLABLE : substrait::Type_Nullability_NULLABILITY_REQUIRED; - const auto ch_type_without_nullable = DB::removeNullable(ch_type); + const auto ch_type_without_nullable = DB::removeNullable(ch_type_wo_lowcardinality); const DB::WhichDataType which(ch_type_without_nullable); auto res = std::make_shared<substrait::Type>(); diff --git a/cpp-ch/local-engine/Common/MergeTreeTool.cpp b/cpp-ch/local-engine/Common/MergeTreeTool.cpp index d1727b740..2f6b4602d 100644 --- a/cpp-ch/local-engine/Common/MergeTreeTool.cpp +++ b/cpp-ch/local-engine/Common/MergeTreeTool.cpp @@ -91,6 +91,8 @@ MergeTreeTable parseMergeTreeTableString(const std::string & info) readString(table.primary_key, in); assertChar('\n', in); } + readString(table.low_card_key, in); + assertChar('\n', in); readString(table.relative_path, in); assertChar('\n', in); readString(table.absolute_path, in); diff --git a/cpp-ch/local-engine/Common/MergeTreeTool.h b/cpp-ch/local-engine/Common/MergeTreeTool.h index e410e50f6..a6af7ebca 100644 --- a/cpp-ch/local-engine/Common/MergeTreeTool.h +++ b/cpp-ch/local-engine/Common/MergeTreeTool.h @@ -50,6 +50,7 @@ struct MergeTreeTable std::string table; substrait::NamedStruct schema; std::string order_by_key; + std::string low_card_key; std::string primary_key = ""; std::string relative_path; std::string absolute_path; diff --git a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp index 405d5c5f5..8737ecb7d 100644 --- a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp +++ b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp @@ -26,6 +26,8 @@ #include "MergeTreeRelParser.h" +#include <Poco/StringTokenizer.h> + namespace DB { @@ -69,7 +71,7 @@ CustomStorageMergeTreePtr MergeTreeRelParser::parseStorage( auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value()); DB::Block header; chassert(rel.has_base_schema()); - header = TypeParser::buildBlockFromNamedStruct(rel.base_schema()); + header = TypeParser::buildBlockFromNamedStruct(rel.base_schema(), merge_tree_table.low_card_key); auto names_and_types_list = header.getNamesAndTypesList(); auto storage_factory = StorageMergeTreeFactory::instance(); auto metadata = buildMetaData(names_and_types_list, context, merge_tree_table); @@ -105,7 +107,7 @@ MergeTreeRelParser::parseReadRel( table.ParseFromString(extension_table.detail().value()); auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value()); DB::Block header; - header = TypeParser::buildBlockFromNamedStruct(merge_tree_table.schema); + header = TypeParser::buildBlockFromNamedStruct(merge_tree_table.schema, merge_tree_table.low_card_key); DB::Block input; if (rel.has_base_schema() && rel.base_schema().names_size()) { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index b3295be82..7bdc80ca8 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -95,14 +95,14 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; - extern const int NO_SUCH_DATA_PART; - extern const int UNKNOWN_FUNCTION; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern 
const int INVALID_JOIN_ON_EXPRESSION; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; +extern const int NO_SUCH_DATA_PART; +extern const int UNKNOWN_FUNCTION; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int INVALID_JOIN_ON_EXPRESSION; } } @@ -152,13 +152,16 @@ void SerializedPlanParser::parseExtensions( if (extension.has_extension_function()) { function_mapping.emplace( - std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); + std::to_string(extension.extension_function().function_anchor()), + extension.extension_function().name()); } } } std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG( - const std::vector<substrait::Expression> & expressions, const DB::Block & header, const DB::Block & read_schema) + const std::vector<substrait::Expression> & expressions, + const DB::Block & header, + const DB::Block & read_schema) { auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); NamesWithAliases required_columns; @@ -263,7 +266,8 @@ std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool nul bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) { - return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with("iterator"); + return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with( + "iterator"); } bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel) @@ -311,7 +315,10 @@ QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait auto iter_index = std::stoi(iter.substr(pos + 1, iter.size())); auto source = std::make_shared<SourceFromJavaIter>( - context, TypeParser::buildBlockFromNamedStruct(rel.base_schema()), input_iters[iter_index], materialize_inputs[iter_index]); + context, + TypeParser::buildBlockFromNamedStruct(rel.base_schema()), + input_iters[iter_index], + materialize_inputs[iter_index]); QueryPlanStepPtr source_step = std::make_unique<ReadFromPreparedSource>(Pipe(source)); source_step->setStepDescription("Read From Java Iter"); return source_step; @@ -368,6 +375,24 @@ DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr n DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) { + if (nullable && !nested_type->isNullable()) + { + if (nested_type->isLowCardinalityNullable()) + { + return nested_type; + } + else + { + if (!nested_type->lowCardinality()) + return std::make_shared<DataTypeNullable>(nested_type); + else + return std::make_shared<DataTypeLowCardinality>( + std::make_shared<DataTypeNullable>( + dynamic_cast<const DataTypeLowCardinality &>(*nested_type).getDictionaryType())); + } + } + + if (nullable && !nested_type->isNullable()) return std::make_shared<DataTypeNullable>(nested_type); else @@ -428,13 +453,15 @@ QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan) { if (type->isNullable()) { - final_cols.emplace_back(type->createColumn(), std::make_shared<DB::DataTypeNullable>(col.type), col.name); + auto wrapped = wrapNullableType(true, col.type); + final_cols.emplace_back(type->createColumn(), wrapped, col.name); + need_final_project = !wrapped->equals(*col.type); } else { final_cols.emplace_back(type->createColumn(), DB::removeNullable(col.type), col.name); + 
need_final_project = true; } - need_final_project = true; } else { @@ -580,7 +607,9 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co { if (args.size() != 2) throw Exception( - ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); + ErrorCodes::BAD_ARGUMENTS, + "Spark function extract requires two args, function:{}", + function.ShortDebugString()); // Get the first arg: field const auto & extract_field = args.at(0); @@ -696,7 +725,9 @@ void SerializedPlanParser::parseArrayJoinArguments( /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 if (scalar_function.arguments_size() != 1) throw Exception( - ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 instead of {}", scalar_function.arguments_size()); + ErrorCodes::BAD_ARGUMENTS, + "Argument number of arrayJoin should be 1 instead of {}", + scalar_function.arguments_size()); auto function_name_copy = function_name; parseFunctionArguments(actions_dag, parsed_args, function_name_copy, scalar_function); @@ -735,7 +766,11 @@ void SerializedPlanParser::parseArrayJoinArguments( } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( - const substrait::Expression & rel, std::vector<String> & result_names, DB::ActionsDAGPtr actions_dag, bool keep_result, bool position) + const substrait::Expression & rel, + std::vector<String> & result_names, + DB::ActionsDAGPtr actions_dag, + bool keep_result, + bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -759,8 +794,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared<DataTypeUInt32>(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * - { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -852,7 +886,10 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( - const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag, bool keep_result) + const substrait::Expression & rel, + std::string & result_name, + DB::ActionsDAGPtr actions_dag, + bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -868,7 +905,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (func_parser) { LOG_DEBUG( - &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); + &Poco::Logger::get("SerializedPlanParser"), + "parse function {} by function parser: {}", + func_name, + func_parser->getName()); const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); @@ 
-920,10 +960,12 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared<DataTypeUInt32>(); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back( + &actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back( + &actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } else if (startsWith(function_signature, "make_decimal:")) @@ -938,10 +980,12 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared<DataTypeUInt32>(); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back(&actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back( + &actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back( + &actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } @@ -960,8 +1004,9 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() + ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name, DB::CastType::accurateOrNull); } @@ -971,8 +1016,9 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() + ? 
local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name); } } @@ -987,7 +1033,9 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( } bool SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs( - ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun) + ActionsDAGPtr actions_dag, + ActionsDAG::NodeRawConstPtrs & args, + const substrait::Expression_ScalarFunction & arithmeticFun) { auto function_signature = function_mapping.at(std::to_string(arithmeticFun.function_reference())); auto pos = function_signature.find(':'); @@ -1187,7 +1235,9 @@ void SerializedPlanParser::parseFunctionArgument( } const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( - DB::ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) + DB::ActionsDAGPtr & actions_dag, + const std::string & function_name, + const substrait::FunctionArgument & arg) { const DB::ActionsDAG::Node * res; if (arg.value().has_scalar_function()) @@ -1207,7 +1257,7 @@ const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( // Convert signed integer index into unsigned integer index std::pair<DB::DataTypePtr, DB::Field> SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field) { -// For tupelElement, field index starts from 1, but int substrait plan, it starts from 0. + // For tupelElement, field index starts from 1, but int substrait plan, it starts from 0. #define UINT_CONVERT(type_ptr, field, type_name) \ if ((type_ptr)->getTypeId() == DB::TypeIndex::type_name) \ { \ @@ -1229,7 +1279,11 @@ std::pair<DB::DataTypePtr, DB::Field> SerializedPlanParser::convertStructFieldTy } ActionsDAGPtr SerializedPlanParser::parseFunction( - const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) + const Block & header, + const substrait::Expression & rel, + std::string & result_name, + ActionsDAGPtr actions_dag, + bool keep_result) { if (!actions_dag) actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); @@ -1239,7 +1293,11 @@ ActionsDAGPtr SerializedPlanParser::parseFunction( } ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression( - const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) + const Block & header, + const substrait::Expression & rel, + std::string & result_name, + ActionsDAGPtr actions_dag, + bool keep_result) { if (!actions_dag) actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header)); @@ -1320,8 +1378,7 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple( = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared<DataTypeUInt32>(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * - { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name 
= "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -1546,7 +1603,9 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait } default: { throw Exception( - ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); + ErrorCodes::UNKNOWN_TYPE, + "Unsupported spark literal type {}", + magic_enum::enum_name(literal.literal_type_case())); } } return std::make_pair(std::move(type), std::move(field)); @@ -1585,7 +1644,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act { function_node = toFunctionNode(actions_dag, "sparkToDate", args); } - else if(DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_timestamp()) + else if (DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_timestamp()) { function_node = toFunctionNode(actions_dag, "sparkToDateTime", args); } @@ -1597,7 +1656,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act else { DataTypePtr ch_type = TypeParser::parseType(substrait_type); - if(DB::isString(DB::removeNullable(ch_type)) && isDecimalOrNullableDecimal(args[0]->result_type)) + if (DB::isString(DB::removeNullable(ch_type)) && isDecimalOrNullableDecimal(args[0]->result_type)) { UInt8 scale = getDecimalScale(*DB::removeNullable(args[0]->result_type)); args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale))); @@ -1750,7 +1809,9 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(const std::string & split_info) { substrait::ReadRel::ExtensionTable extension_table; - google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size())); + google::protobuf::io::CodedInputStream coded_in( + reinterpret_cast<const uint8_t *>(split_info.data()), + static_cast<int>(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = extension_table.ParseFromCodedStream(&coded_in); @@ -1763,7 +1824,9 @@ substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(con substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std::string & split_info) { substrait::ReadRel::LocalFiles local_files; - google::protobuf::io::CodedInputStream coded_in(reinterpret_cast<const uint8_t *>(split_info.data()), static_cast<int>(split_info.size())); + google::protobuf::io::CodedInputStream coded_in( + reinterpret_cast<const uint8_t *>(split_info.data()), + static_cast<int>(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = local_files.ParseFromCodedStream(&coded_in); @@ -1808,7 +1871,8 @@ QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) return parse(std::move(plan_ptr)); } -SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) +SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) + : context(context_) { } @@ -1817,10 +1881,13 @@ ContextMutablePtr SerializedPlanParser::global_context = nullptr; Context::ConfigurationPtr SerializedPlanParser::config = nullptr; void SerializedPlanParser::collectJoinKeys( - const substrait::Expression & condition, std::vector<std::pair<int32_t, int32_t>> & join_keys, int32_t right_key_start) + const substrait::Expression & condition, + std::vector<std::pair<int32_t, int32_t>> & join_keys, + 
int32_t right_key_start) { auto condition_name = getFunctionName( - function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); + function_mapping.at(std::to_string(condition.scalar_function().function_reference())), + condition.scalar_function()); if (condition_name == "and") { collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); @@ -1878,7 +1945,9 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & } void ASTParser::parseFunctionArgumentsToAST( - const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) + const Names & names, + const substrait::Expression_ScalarFunction & scalar_function, + ASTs & ast_args) { const auto & args = scalar_function.arguments(); @@ -2037,7 +2106,9 @@ void SerializedPlanParser::removeNullable(const std::set<String> & require_colum } void SerializedPlanParser::wrapNullable( - const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names) + const std::vector<String> & columns, + ActionsDAGPtr actions_dag, + std::map<std::string, std::string> & nullable_measure_names) { for (const auto & item : columns) { @@ -2087,7 +2158,8 @@ void LocalExecutor::execute(QueryPlanPtr query_plan) optimization_settings, BuildQueryPipelineSettings{ .actions_settings - = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}, + = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, + .compile_expressions = CompileExpressions::yes}, .process_list_element = query_status}); LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}", PlanUtil::explainPlan(*current_query_plan)); @@ -2183,7 +2255,8 @@ Block & LocalExecutor::getHeader() } LocalExecutor::LocalExecutor(QueryContext & _query_context, ContextPtr context_) - : query_context(_query_context), context(context_) + : query_context(_query_context) + , context(context_) { } @@ -2211,8 +2284,12 @@ std::string LocalExecutor::dumpPipeline() } NonNullableColumnsResolver::NonNullableColumnsResolver( - const DB::Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_) - : header(header_), parser(parser_), cond_rel(cond_rel_) + const DB::Block & header_, + SerializedPlanParser & parser_, + const substrait::Expression & cond_rel_) + : header(header_) + , parser(parser_) + , cond_rel(cond_rel_) { } @@ -2284,7 +2361,8 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & } std::string NonNullableColumnsResolver::safeGetFunctionName( - const std::string & function_signature, const substrait::Expression_ScalarFunction & function) + const std::string & function_signature, + const substrait::Expression_ScalarFunction & function) { try { @@ -2295,4 +2373,4 @@ std::string NonNullableColumnsResolver::safeGetFunctionName( return ""; } } -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp b/cpp-ch/local-engine/Parser/TypeParser.cpp index 2edd8c1c8..3ffa0b0d9 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.cpp +++ b/cpp-ch/local-engine/Parser/TypeParser.cpp @@ -43,8 +43,8 @@ namespace DB { namespace ErrorCodes { - extern const int UNKNOWN_TYPE; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +extern const int UNKNOWN_TYPE; +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } } @@ -238,8 
+238,13 @@ DB::DataTypePtr TypeParser::parseType(const substrait::Type & substrait_type, st } -DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & struct_) +DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & struct_, const std::string & low_card_cols) { + std::unordered_set<std::string> low_card_columns; + Poco::StringTokenizer tokenizer(low_card_cols, ","); + for (const auto & token : tokenizer) + low_card_columns.insert(token); + DB::ColumnsWithTypeAndName internal_cols; internal_cols.reserve(struct_.names_size()); std::list<std::string> field_names; @@ -252,6 +257,11 @@ DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & s const auto & substrait_type = struct_.struct_().types(i); auto ch_type = parseType(substrait_type, &field_names); + if (low_card_columns.contains(name)) + { + ch_type = std::make_shared<DB::DataTypeLowCardinality>(ch_type); + } + // This is a partial aggregate data column. // It's type is special, must be a struct type contains all arguments types. // Notice: there are some coincidence cases in which the type is not a struct type, e.g. name is "_1#913 + _2#914#928". We need to handle it. @@ -271,8 +281,8 @@ DB::Block TypeParser::buildBlockFromNamedStruct(const substrait::NamedStruct & s auto agg_function_name = function_parser->getCHFunctionName(args_types); auto action = NullsAction::EMPTY; ch_type = AggregateFunctionFactory::instance() - .get(agg_function_name, action, args_types, function_parser->getDefaultFunctionParameters(), properties) - ->getStateType(); + .get(agg_function_name, action, args_types, function_parser->getDefaultFunctionParameters(), properties) + ->getStateType(); } internal_cols.push_back(ColumnWithTypeAndName(ch_type, name)); @@ -295,7 +305,7 @@ DB::Block TypeParser::buildBlockFromNamedStructWithoutDFS(const substrait::Named for (int i = 0; i < size; ++i) { const auto & name = names[i]; - const auto & type = types[i]; + const auto & type = types[i]; auto ch_type = parseType(type); columns.emplace_back(ColumnWithTypeAndName(ch_type, name)); } @@ -327,4 +337,4 @@ DB::DataTypePtr TypeParser::tryWrapNullable(substrait::Type_Nullability nullable return std::make_shared<DB::DataTypeNullable>(nested_type); return nested_type; } -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Parser/TypeParser.h b/cpp-ch/local-engine/Parser/TypeParser.h index a25b2f50a..55420ee1a 100644 --- a/cpp-ch/local-engine/Parser/TypeParser.h +++ b/cpp-ch/local-engine/Parser/TypeParser.h @@ -42,7 +42,8 @@ public: return parseType(substrait_type, nullptr); } - static DB::Block buildBlockFromNamedStruct(const substrait::NamedStruct & struct_); + // low_card_cols is in format of "cola,colb". Currently does not nested column to be LowCardinality. 
+ static DB::Block buildBlockFromNamedStruct(const substrait::NamedStruct & struct_, const std::string& low_card_cols = ""); /// Build block from substrait NamedStruct without DFS rules, different from buildBlockFromNamedStruct static DB::Block buildBlockFromNamedStructWithoutDFS(const substrait::NamedStruct & struct_); diff --git a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp index e9d148bca..28c57c31f 100644 --- a/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp +++ b/cpp-ch/local-engine/Shuffle/ShuffleSplitter.cpp @@ -301,7 +301,7 @@ void ColumnsBuffer::appendSelective( accumulated_columns.reserve(source.columns()); for (size_t i = 0; i < source.columns(); i++) { - auto column = source.getColumns()[i]->convertToFullColumnIfConst()->convertToFullColumnIfSparse()->cloneEmpty(); + auto column = source.getColumns()[i]->convertToFullColumnIfConst()->convertToFullIfNeeded()->cloneEmpty(); column->reserve(prefer_buffer_size); accumulated_columns.emplace_back(std::move(column)); } @@ -310,7 +310,7 @@ void ColumnsBuffer::appendSelective( if (!accumulated_columns[column_idx]->onlyNull()) { accumulated_columns[column_idx]->insertRangeSelective( - *source.getByPosition(column_idx).column->convertToFullColumnIfConst()->convertToFullColumnIfSparse(), selector, from, length); + *source.getByPosition(column_idx).column->convertToFullIfNeeded(), selector, from, length); } else { diff --git a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp index 39a0cb7b5..7f09721ab 100644 --- a/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp +++ b/cpp-ch/local-engine/Storages/IO/NativeWriter.cpp @@ -23,6 +23,7 @@ #include <Storages/IO/AggregateSerializationUtils.h> #include <Functions/FunctionHelpers.h> #include <DataTypes/DataTypeAggregateFunction.h> +#include <DataTypes/DataTypeLowCardinality.h> using namespace DB; @@ -72,6 +73,7 @@ size_t NativeWriter::write(const DB::Block & block) auto column = block.safeGetByPosition(i); /// agg state will convert to fixedString, need write actual agg state type auto original_type = header.safeGetByPosition(i).type; + original_type = recursiveRemoveLowCardinality(original_type); /// Type String type_name = original_type->getName(); bool is_agg_opt = WhichDataType(original_type).isAggregateFunction() @@ -85,8 +87,9 @@ size_t NativeWriter::write(const DB::Block & block) writeStringBinary(type_name, ostr); } - SerializationPtr serialization = column.type->getDefaultSerialization(); column.column = recursiveRemoveSparse(column.column); + column.column = recursiveRemoveLowCardinality(column.column); + SerializationPtr serialization = recursiveRemoveLowCardinality(column.type)->getDefaultSerialization(); /// Data if (rows) /// Zero items of data is always represented as zero number of bytes. 
{ diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java index 76b0a31c2..9dc26215b 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableBuilder.java @@ -30,6 +30,7 @@ public class ExtensionTableBuilder { String relativeTablePath, String absoluteTablePath, String orderByKey, + String lowCardKey, String primaryKey, List<String> partList, List<Long> starts, @@ -45,6 +46,7 @@ public class ExtensionTableBuilder { relativeTablePath, absoluteTablePath, orderByKey, + lowCardKey, primaryKey, partList, starts, diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java index b721477d8..bf942ef26 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExtensionTableNode.java @@ -43,6 +43,8 @@ public class ExtensionTableNode implements SplitInfo { private String primaryKey; + private String lowCardKey; + private List<String> partList; private List<Long> starts; private List<Long> lengths; @@ -57,6 +59,7 @@ public class ExtensionTableNode implements SplitInfo { String relativePath, String absolutePath, String orderByKey, + String lowCardKey, String primaryKey, List<String> partList, List<Long> starts, @@ -76,6 +79,7 @@ public class ExtensionTableNode implements SplitInfo { this.absolutePath = absolutePath; this.tableSchemaJson = tableSchemaJson; this.orderByKey = orderByKey; + this.lowCardKey = lowCardKey; this.primaryKey = primaryKey; this.partList = partList; this.starts = starts; @@ -110,6 +114,7 @@ public class ExtensionTableNode implements SplitInfo { if (!this.orderByKey.isEmpty() && !this.orderByKey.equals("tuple()")) { extensionTableStr.append(this.primaryKey).append("\n"); } + extensionTableStr.append(this.lowCardKey).append("\n"); extensionTableStr.append(this.relativePath).append("\n"); extensionTableStr.append(this.absolutePath).append("\n"); diff --git a/shims/common/src/main/scala/io/glutenproject/execution/datasource/GlutenFormatWriterInjects.scala b/shims/common/src/main/scala/io/glutenproject/execution/datasource/GlutenFormatWriterInjects.scala index d44a3dd22..856974a0e 100644 --- a/shims/common/src/main/scala/io/glutenproject/execution/datasource/GlutenFormatWriterInjects.scala +++ b/shims/common/src/main/scala/io/glutenproject/execution/datasource/GlutenFormatWriterInjects.scala @@ -42,6 +42,7 @@ trait GlutenFormatWriterInjects { database: String, tableName: String, orderByKeyOption: Option[Seq[String]], + lowCardKeyOption: Option[Seq[String]], primaryKeyOption: Option[Seq[String]], partitionColumns: Seq[String], tableSchema: StructType, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@gluten.apache.org For additional commands, e-mail: commits-h...@gluten.apache.org
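For readers skimming this patch, here is a minimal usage sketch of the new `lowCardKey` table property, adapted from the test case added in GlutenClickHouseMergeTreeWriteSuite. The table name, column set, and location are illustrative, and an existing SparkSession `spark` with the Gluten ClickHouse backend enabled is assumed.

```scala
// Declare which columns of a ClickHouse MergeTree table should be stored as
// LowCardinality via the 'lowCardKey' table property (comma-separated names).
spark.sql("""
  CREATE TABLE IF NOT EXISTS lineitem_mergetree_lowcard (
    l_orderkey   BIGINT,
    l_returnflag STRING,
    l_linestatus STRING,
    l_comment    STRING
  )
  USING clickhouse
  LOCATION '/tmp/lineitem_mergetree_lowcard'
  TBLPROPERTIES ('lowCardKey' = 'l_returnflag,L_LINESTATUS')
""")

// Writes and reads go through the normal path; the listed columns are created
// as LowCardinality(Nullable(String)) in the written MergeTree parts.
spark.sql("INSERT INTO lineitem_mergetree_lowcard VALUES (1, 'R', 'F', 'example row')")
```

Per the test's assertions, the part's columns.txt then shows `LowCardinality(Nullable(String))` for the listed columns, and queries that group on or aggregate these columns still return the same results as before.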
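The property handling itself is small: ClickHouseTableV2 splits the comma-separated value, rejects nested columns (names containing '.'), and lower-cases the names before they are joined back into a single string for the native plan. A simplified sketch of that logic, not the exact committed code (`properties` stands in for the table's property map):

```scala
// Parse and validate the 'lowCardKey' table property, mirroring the checks
// added in ClickHouseTableV2.lowCardKeyOption.
def parseLowCardKey(properties: java.util.Map[String, String]): Option[Seq[String]] = {
  Option(properties.get("lowCardKey")).filter(_.nonEmpty).map {
    value =>
      val keys = value.split(",").map(_.trim).toSeq
      keys.foreach {
        k =>
          // Nested columns are not supported yet.
          if (k.contains("."))
            throw new IllegalStateException(
              s"lowCardKey $k can not contain '.' (not support nested column yet)")
      }
      keys.map(_.toLowerCase)
  }
}

// Downstream the option is flattened to a plain string for the substrait
// extension table, e.g. "l_returnflag,l_linestatus" (empty when unset).
def toNativeLowCardKey(lowCardKeyOption: Option[Seq[String]]): String =
  lowCardKeyOption.map(_.mkString(",")).getOrElse("")
```

On the native side, this string is carried in MergeTreeTable.low_card_key and TypeParser.buildBlockFromNamedStruct wraps the matching top-level columns in DataTypeLowCardinality when building the block.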