This is an automated email from the ASF dual-hosted git repository.
biyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 4aecb01b07 [spark] Support ACCEPT_ANY_SCHEMA for spark v2 write (#6281)
4aecb01b07 is described below
commit 4aecb01b079eb0d4ee4889d51ca5057b7b13f51a
Author: Kerwin Zhang <[email protected]>
AuthorDate: Tue Sep 23 22:23:24 2025 +0800
[spark] Support ACCEPT_ANY_SCHEMA for spark v2 write (#6281)
---
.../paimon/spark/SparkInternalRowWrapper.java | 148 ++++++++--
.../scala/org/apache/paimon/spark/SparkTable.scala | 3 +-
.../paimon/spark/commands/SchemaHelper.scala | 34 ++-
.../apache/paimon/spark/write/PaimonV2Write.scala | 35 ++-
.../paimon/spark/write/PaimonV2WriteBuilder.scala | 5 +-
.../paimon/spark/sql/V2WriteMergeSchemaTest.scala | 319 +++++++++++++++++++++
6 files changed, 491 insertions(+), 53 deletions(-)
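In short, this change lets the Spark v2 write path accept input whose schema differs from the table schema (extra or missing columns) and delegates the evolution to write.merge-schema. A minimal usage sketch, mirroring the new V2WriteMergeSchemaTest below (the table name t and the sample values are illustrative only, and it assumes spark.paimon.write.use-v2-write is set to true as in the test's SparkConf):

    // Scala, e.g. in spark-shell with the Paimon catalog configured.
    spark.sql("CREATE TABLE t (a INT, b STRING)")
    // The input carries an extra column c; with ACCEPT_ANY_SCHEMA the v2 write
    // no longer requires an exact schema match, and merge-schema evolves the table.
    Seq((3, "3", 3))
      .toDF("a", "b", "c")
      .writeTo("t")
      .option("write.merge-schema", "true")
      .append()
    // SQL path exercised by the same test:
    // INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c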
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
index 0d00495c69..7de1695af0 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkInternalRowWrapper.java
@@ -43,6 +43,8 @@ import org.apache.spark.sql.types.TimestampType;
import java.io.Serializable;
import java.math.BigDecimal;
+import java.util.HashMap;
+import java.util.Map;
/** Wrapper to fetch value from the spark internal row. */
public class SparkInternalRowWrapper implements InternalRow, Serializable {
@@ -50,23 +52,32 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
private transient org.apache.spark.sql.catalyst.InternalRow internalRow;
private final int length;
private final int rowKindIdx;
- private final StructType structType;
+ private final StructType tableSchema;
+ private int[] fieldIndexMap = null;
public SparkInternalRowWrapper(
org.apache.spark.sql.catalyst.InternalRow internalRow,
int rowKindIdx,
- StructType structType,
+ StructType tableSchema,
int length) {
this.internalRow = internalRow;
this.rowKindIdx = rowKindIdx;
this.length = length;
- this.structType = structType;
+ this.tableSchema = tableSchema;
}
- public SparkInternalRowWrapper(int rowKindIdx, StructType structType, int length) {
+ public SparkInternalRowWrapper(int rowKindIdx, StructType tableSchema, int length) {
this.rowKindIdx = rowKindIdx;
this.length = length;
- this.structType = structType;
+ this.tableSchema = tableSchema;
+ }
+
+ public SparkInternalRowWrapper(
+ int rowKindIdx, StructType tableSchema, StructType dataSchema, int length) {
+ this.rowKindIdx = rowKindIdx;
+ this.length = length;
+ this.tableSchema = tableSchema;
+ this.fieldIndexMap = buildFieldIndexMap(tableSchema, dataSchema);
}
public SparkInternalRowWrapper replace(org.apache.spark.sql.catalyst.InternalRow internalRow) {
@@ -74,6 +85,42 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
return this;
}
+ private int[] buildFieldIndexMap(StructType schemaStruct, StructType dataSchema) {
+ int[] mapping = new int[schemaStruct.size()];
+
+ Map<String, Integer> rowFieldIndexMap = new HashMap<>();
+ for (int i = 0; i < dataSchema.size(); i++) {
+ rowFieldIndexMap.put(dataSchema.fields()[i].name(), i);
+ }
+
+ for (int i = 0; i < schemaStruct.size(); i++) {
+ String fieldName = schemaStruct.fields()[i].name();
+ Integer index = rowFieldIndexMap.get(fieldName);
+ mapping[i] = (index != null) ? index : -1;
+ }
+
+ return mapping;
+ }
+
+ private int getActualFieldPosition(int pos) {
+ if (fieldIndexMap == null) {
+ return pos;
+ } else {
+ if (pos < 0 || pos >= fieldIndexMap.length) {
+ return -1;
+ }
+ return fieldIndexMap[pos];
+ }
+ }
+
+ private int validateAndGetActualPosition(int pos) {
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1) {
+ throw new ArrayIndexOutOfBoundsException("Field index out of bounds: " + pos);
+ }
+ return actualPos;
+ }
+
@Override
public int getFieldCount() {
return length;
@@ -82,10 +129,12 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
@Override
public RowKind getRowKind() {
if (rowKindIdx != -1) {
- return RowKind.fromByteValue(internalRow.getByte(rowKindIdx));
- } else {
- return RowKind.INSERT;
+ int actualPos = getActualFieldPosition(rowKindIdx);
+ if (actualPos != -1) {
+ return RowKind.fromByteValue(internalRow.getByte(actualPos));
+ }
}
+ return RowKind.INSERT;
}
@Override
@@ -95,69 +144,102 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
@Override
public boolean isNullAt(int pos) {
- return internalRow.isNullAt(pos);
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1) {
+ return true;
+ }
+ return internalRow.isNullAt(actualPos);
}
@Override
public boolean getBoolean(int pos) {
- return internalRow.getBoolean(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getBoolean(actualPos);
}
@Override
public byte getByte(int pos) {
- return internalRow.getByte(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getByte(actualPos);
}
@Override
public short getShort(int pos) {
- return internalRow.getShort(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getShort(actualPos);
}
@Override
public int getInt(int pos) {
- return internalRow.getInt(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getInt(actualPos);
}
@Override
public long getLong(int pos) {
- return internalRow.getLong(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getLong(actualPos);
}
@Override
public float getFloat(int pos) {
- return internalRow.getFloat(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getFloat(actualPos);
}
@Override
public double getDouble(int pos) {
- return internalRow.getDouble(pos);
+ int actualPos = validateAndGetActualPosition(pos);
+ return internalRow.getDouble(actualPos);
}
@Override
public BinaryString getString(int pos) {
- return BinaryString.fromBytes(internalRow.getUTF8String(pos).getBytes());
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ return BinaryString.fromBytes(internalRow.getUTF8String(actualPos).getBytes());
}
@Override
public Decimal getDecimal(int pos, int precision, int scale) {
- org.apache.spark.sql.types.Decimal decimal = internalRow.getDecimal(pos, precision, scale);
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ org.apache.spark.sql.types.Decimal decimal =
+ internalRow.getDecimal(actualPos, precision, scale);
BigDecimal bigDecimal = decimal.toJavaBigDecimal();
return Decimal.fromBigDecimal(bigDecimal, precision, scale);
}
@Override
public Timestamp getTimestamp(int pos, int precision) {
- return convertToTimestamp(structType.fields()[pos].dataType(), internalRow.getLong(pos));
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ return convertToTimestamp(
+ tableSchema.fields()[pos].dataType(), internalRow.getLong(actualPos));
}
@Override
public byte[] getBinary(int pos) {
- return internalRow.getBinary(pos);
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ return internalRow.getBinary(actualPos);
}
@Override
public Variant getVariant(int pos) {
- return SparkShimLoader.shim().toPaimonVariant(internalRow, pos);
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ return SparkShimLoader.shim().toPaimonVariant(internalRow, actualPos);
}
@Override
@@ -167,24 +249,36 @@ public class SparkInternalRowWrapper implements InternalRow, Serializable {
@Override
public InternalArray getArray(int pos) {
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
return new SparkInternalArray(
- internalRow.getArray(pos),
- ((ArrayType) (structType.fields()[pos].dataType())).elementType());
+ internalRow.getArray(actualPos),
+ ((ArrayType) (tableSchema.fields()[pos].dataType())).elementType());
}
@Override
public InternalMap getMap(int pos) {
- MapType mapType = (MapType) structType.fields()[pos].dataType();
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
+ MapType mapType = (MapType) tableSchema.fields()[pos].dataType();
return new SparkInternalMap(
- internalRow.getMap(pos), mapType.keyType(), mapType.valueType());
+ internalRow.getMap(actualPos), mapType.keyType(), mapType.valueType());
}
@Override
public InternalRow getRow(int pos, int numFields) {
+ int actualPos = getActualFieldPosition(pos);
+ if (actualPos == -1 || internalRow.isNullAt(actualPos)) {
+ return null;
+ }
return new SparkInternalRowWrapper(
- internalRow.getStruct(pos, numFields),
+ internalRow.getStruct(actualPos, numFields),
-1,
- (StructType) structType.fields()[pos].dataType(),
+ (StructType) tableSchema.fields()[actualPos].dataType(),
numFields);
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
index 305a7191d8..e79e148ebb 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/SparkTable.scala
@@ -108,6 +108,7 @@ case class SparkTable(table: Table)
)
if (useV2Write) {
+ capabilities.add(TableCapability.ACCEPT_ANY_SCHEMA)
capabilities.add(TableCapability.BATCH_WRITE)
capabilities.add(TableCapability.OVERWRITE_DYNAMIC)
} else {
@@ -152,7 +153,7 @@ case class SparkTable(table: Table)
case fileStoreTable: FileStoreTable =>
val options = Options.fromMap(info.options)
if (useV2Write) {
- new PaimonV2WriteBuilder(fileStoreTable, info.schema())
+ new PaimonV2WriteBuilder(fileStoreTable, info.schema(), options)
} else {
new PaimonWriteBuilder(fileStoreTable, options)
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
index d66a941929..06f749b8ec 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/SchemaHelper.scala
@@ -40,19 +40,8 @@ private[spark] trait SchemaHelper extends WithFileStoreTable {
override def table: FileStoreTable = newTable.getOrElse(originTable)
def mergeSchema(sparkSession: SparkSession, input: DataFrame, options: Options): DataFrame = {
- val mergeSchemaEnabled =
- options.get(SparkConnectorOptions.MERGE_SCHEMA) || OptionUtils.writeMergeSchemaEnabled()
- if (!mergeSchemaEnabled) {
- return input
- }
-
val dataSchema = SparkSystemColumns.filterSparkSystemColumns(input.schema)
- val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) || OptionUtils
- .writeMergeSchemaExplicitCastEnabled()
- mergeAndCommitSchema(dataSchema, allowExplicitCast)
-
- // For case that some columns is absent in data, we still allow to write once write.merge-schema is true.
- val newTableSchema = SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType())
+ val newTableSchema = mergeSchema(input.schema, options)
if (!PaimonUtils.sameType(newTableSchema, dataSchema)) {
val resolve = sparkSession.sessionState.conf.resolver
val cols = newTableSchema.map {
@@ -68,6 +57,27 @@ private[spark] trait SchemaHelper extends WithFileStoreTable {
}
}
+ def mergeSchema(dataSchema: StructType, options: Options): StructType = {
+ val mergeSchemaEnabled =
+ options.get(SparkConnectorOptions.MERGE_SCHEMA) || OptionUtils.writeMergeSchemaEnabled()
+ if (!mergeSchemaEnabled) {
+ return dataSchema
+ }
+
+ val filteredDataSchema = SparkSystemColumns.filterSparkSystemColumns(dataSchema)
+ val allowExplicitCast = options.get(SparkConnectorOptions.EXPLICIT_CAST) || OptionUtils
+ .writeMergeSchemaExplicitCastEnabled()
+ mergeAndCommitSchema(filteredDataSchema, allowExplicitCast)
+
+ val writeSchema = SparkTypeUtils.fromPaimonRowType(table.schema().logicalRowType())
+
+ if (!PaimonUtils.sameType(writeSchema, filteredDataSchema)) {
+ writeSchema
+ } else {
+ filteredDataSchema
+ }
+ }
+
private def mergeAndCommitSchema(dataSchema: StructType, allowExplicitCast: Boolean): Unit = {
val dataRowType = SparkTypeUtils.toPaimonType(dataSchema).asInstanceOf[RowType]
if (table.store().mergeSchema(dataRowType, allowExplicitCast)) {
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
index 8eaeffe2fc..9eaa1bf72f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2Write.scala
@@ -19,7 +19,9 @@
package org.apache.paimon.spark.write
import org.apache.paimon.CoreOptions
+import org.apache.paimon.options.Options
import org.apache.paimon.spark.{SparkInternalRowWrapper, SparkUtils}
+import org.apache.paimon.spark.commands.SchemaHelper
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.{BatchTableWrite, BatchWriteBuilder, CommitMessage, CommitMessageSerializer}
@@ -36,21 +38,24 @@ import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
class PaimonV2Write(
- storeTable: FileStoreTable,
+ override val originTable: FileStoreTable,
overwriteDynamic: Boolean,
overwritePartitions: Option[Map[String, String]],
- writeSchema: StructType
+ dataSchema: StructType,
+ options: Options
) extends Write
with RequiresDistributionAndOrdering
+ with SchemaHelper
with Logging {
assert(
!(overwriteDynamic && overwritePartitions.exists(_.nonEmpty)),
"Cannot overwrite dynamically and by filter both")
- private val table =
- storeTable.copy(
- Map(CoreOptions.DYNAMIC_PARTITION_OVERWRITE.key -> overwriteDynamic.toString).asJava)
+ private val writeSchema = mergeSchema(dataSchema, options)
+
+ updateTableWithOptions(
+ Map(CoreOptions.DYNAMIC_PARTITION_OVERWRITE.key -> overwriteDynamic.toString))
private val writeRequirement = PaimonWriteRequirement(table)
@@ -66,7 +71,8 @@ class PaimonV2Write(
ordering
}
- override def toBatch: BatchWrite = PaimonBatchWrite(table, writeSchema, overwritePartitions)
+ override def toBatch: BatchWrite =
+ PaimonBatchWrite(table, writeSchema, dataSchema, overwritePartitions)
override def toString: String = {
val overwriteDynamicStr = if (overwriteDynamic) {
@@ -86,6 +92,7 @@ class PaimonV2Write(
private case class PaimonBatchWrite(
table: FileStoreTable,
writeSchema: StructType,
+ dataSchema: StructType,
overwritePartitions: Option[Map[String, String]])
extends BatchWrite
with WriteHelper {
@@ -97,7 +104,7 @@ private case class PaimonBatchWrite(
}
override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
- WriterFactory(writeSchema, batchWriteBuilder)
+ WriterFactory(writeSchema, dataSchema, batchWriteBuilder)
override def useCommitCoordinator(): Boolean = false
@@ -129,16 +136,22 @@ private case class PaimonBatchWrite(
}
}
-private case class WriterFactory(writeSchema: StructType, batchWriteBuilder: BatchWriteBuilder)
+private case class WriterFactory(
+ writeSchema: StructType,
+ dataSchema: StructType,
+ batchWriteBuilder: BatchWriteBuilder)
extends DataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
val batchTableWrite = batchWriteBuilder.newWrite()
- new PaimonDataWriter(batchTableWrite, writeSchema)
+ new PaimonDataWriter(batchTableWrite, writeSchema, dataSchema)
}
}
-private class PaimonDataWriter(batchTableWrite: BatchTableWrite, writeSchema: StructType)
+private class PaimonDataWriter(
+ batchTableWrite: BatchTableWrite,
+ writeSchema: StructType,
+ dataSchema: StructType)
extends DataWriter[InternalRow] {
private val ioManager = SparkUtils.createIOManager()
@@ -146,7 +159,7 @@ private class PaimonDataWriter(batchTableWrite: BatchTableWrite, writeSchema: St
private val rowConverter: InternalRow => SparkInternalRowWrapper = {
val numFields = writeSchema.fields.length
- val reusableWrapper = new SparkInternalRowWrapper(-1, writeSchema, numFields)
+ val reusableWrapper = new SparkInternalRowWrapper(-1, writeSchema, dataSchema, numFields)
record => reusableWrapper.replace(record)
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
index 90f30a3955..d6b747a53f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/write/PaimonV2WriteBuilder.scala
@@ -18,13 +18,14 @@
package org.apache.paimon.spark.write
+import org.apache.paimon.options.Options
import org.apache.paimon.table.FileStoreTable
import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.types.StructType
-class PaimonV2WriteBuilder(table: FileStoreTable, writeSchema: StructType)
+class PaimonV2WriteBuilder(table: FileStoreTable, dataSchema: StructType, options: Options)
extends BaseWriteBuilder(table)
with SupportsOverwrite
with SupportsDynamicOverwrite {
@@ -33,7 +34,7 @@ class PaimonV2WriteBuilder(table: FileStoreTable, writeSchema: StructType)
private var overwritePartitions: Option[Map[String, String]] = None
override def build =
- new PaimonV2Write(table, overwriteDynamic, overwritePartitions, writeSchema)
+ new PaimonV2Write(table, overwriteDynamic, overwritePartitions, dataSchema, options)
override def overwrite(filters: Array[Filter]): WriteBuilder = {
if (overwriteDynamic) {
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala
new file mode 100644
index 0000000000..0b6e589d9f
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/V2WriteMergeSchemaTest.scala
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.Row
+
+class V2WriteMergeSchemaTest extends PaimonSparkTestBase {
+
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set("spark.sql.catalog.paimon.cache-enabled", "false")
+ .set("spark.paimon.write.use-v2-write", "true")
+ .set("spark.paimon.write.merge-schema", "true")
+ .set("spark.paimon.write.merge-schema.explicit-cast", "true")
+ }
+
+ import testImplicits._
+
+ test("Write merge schema: dataframe write") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ Seq((1, "1"), (2, "2"))
+ .toDF("a", "b")
+ .writeTo("t")
+ .option("write.merge-schema", "true")
+ .append()
+
+ // new columns
+ Seq((3, "3", 3))
+ .toDF("a", "b", "c")
+ .writeTo("t")
+ .option("write.merge-schema", "true")
+ .append()
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(Row(1, "1", null), Row(2, "2", null), Row(3, "3", 3))
+ )
+
+ // missing columns and new columns
+ Seq(("4", "4", 4))
+ .toDF("d", "b", "c")
+ .writeTo("t")
+ .option("write.merge-schema", "true")
+ .append()
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(null, "4", 4, "4"),
+ Row(1, "1", null, null),
+ Row(2, "2", null, null),
+ Row(3, "3", 3, null))
+ )
+ }
+ }
+
+ test("Write merge schema: sql write") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ // new columns
+ sql("INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(Row(1, "1", null), Row(2, "2", null), Row(3, "3", 3))
+ )
+
+ // missing columns and new columns
+ sql("INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, 4 AS c")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(null, "4", 4, "4"),
+ Row(1, "1", null, null),
+ Row(2, "2", null, null),
+ Row(3, "3", 3, null))
+ )
+ }
+ }
+
+ test("Write merge schema: fail when merge schema is disabled but new columns
are provided") {
+ withTable("t") {
+ withSparkSQLConf("spark.paimon.write.merge-schema" -> "false") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ val error = intercept[RuntimeException] {
+ spark.sql("INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, 3 AS c")
+ }.getMessage
+ assert(error.contains("the number of data columns don't match with the table schema's"))
+ }
+ }
+ }
+
+ test("Write merge schema: numeric types") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ // new columns with numeric types
+ sql(
+ "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+ "cast(10 as byte) AS byte_col, " +
+ "cast(1000 as short) AS short_col, " +
+ "100000 AS int_col, " +
+ "10000000000L AS long_col, " +
+ "cast(1.23 as float) AS float_col, " +
+ "4.56 AS double_col, " +
+ "cast(7.89 as decimal(10,2)) AS decimal_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(1, "1", null, null, null, null, null, null, null),
+ Row(2, "2", null, null, null, null, null, null, null),
+ Row(
+ 3,
+ "3",
+ 10.toByte,
+ 1000.toShort,
+ 100000,
+ 10000000000L,
+ 1.23f,
+ 4.56d,
+ java.math.BigDecimal.valueOf(7.89))
+ )
+ )
+
+ // missing columns and new columns with numeric types
+ sql(
+ "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+ "cast(20 as byte) AS byte_col, " +
+ "cast(2000 as short) AS short_col, " +
+ "200000 AS int_col, " +
+ "20000000000L AS long_col, " +
+ "cast(2.34 as float) AS float_col, " +
+ "5.67 AS double_col, " +
+ "cast(8.96 as decimal(10,2)) AS decimal_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(
+ null,
+ "4",
+ 20.toByte,
+ 2000.toShort,
+ 200000,
+ 20000000000L,
+ 2.34f,
+ 5.67d,
+ java.math.BigDecimal.valueOf(8.96),
+ "4"),
+ Row(1, "1", null, null, null, null, null, null, null, null),
+ Row(2, "2", null, null, null, null, null, null, null, null),
+ Row(
+ 3,
+ "3",
+ 10.toByte,
+ 1000.toShort,
+ 100000,
+ 10000000000L,
+ 1.23f,
+ 4.56d,
+ java.math.BigDecimal.valueOf(7.89),
+ null)
+ )
+ )
+ }
+ }
+
+ test("Write merge schema: date and time types") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ // new columns with date and time types
+ sql(
+ "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+ "cast('2023-01-01' as date) AS date_col, " +
+ "cast('2023-01-01 12:00:00' as timestamp) AS timestamp_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(1, "1", null, null),
+ Row(2, "2", null, null),
+ Row(
+ 3,
+ "3",
+ java.sql.Date.valueOf("2023-01-01"),
+ java.sql.Timestamp.valueOf("2023-01-01 12:00:00"))
+ )
+ )
+
+ // missing columns and new columns with date and time types
+ sql(
+ "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+ "cast('2023-12-31' as date) AS date_col, " +
+ "cast('2023-12-31 23:59:59' as timestamp) AS timestamp_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(
+ null,
+ "4",
+ java.sql.Date.valueOf("2023-12-31"),
+ java.sql.Timestamp.valueOf("2023-12-31 23:59:59"),
+ "4"),
+ Row(1, "1", null, null, null),
+ Row(2, "2", null, null, null),
+ Row(
+ 3,
+ "3",
+ java.sql.Date.valueOf("2023-01-01"),
+ java.sql.Timestamp.valueOf("2023-01-01 12:00:00"),
+ null)
+ )
+ )
+ }
+ }
+
+ test("Write merge schema: complex types") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ // new columns with complex types
+ sql(
+ "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+ "array(1, 2, 3) AS array_col, " +
+ "map('key1', 'value1', 'key2', 'value2') AS map_col, " +
+ "struct('x', 1) AS struct_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(1, "1", null, null, null),
+ Row(2, "2", null, null, null),
+ Row(3, "3", Array(1, 2, 3), Map("key1" -> "value1", "key2" ->
"value2"), Row("x", 1))
+ )
+ )
+
+ // missing columns and new columns with complex types
+ sql(
+ "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+ "array(4, 5, 6) AS array_col, " +
+ "map('key3', 'value3') AS map_col, " +
+ "struct('y', 2) AS struct_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(null, "4", Array(4, 5, 6), Map("key3" -> "value3"), Row("y", 2),
"4"),
+ Row(1, "1", null, null, null, null),
+ Row(2, "2", null, null, null, null),
+ Row(
+ 3,
+ "3",
+ Array(1, 2, 3),
+ Map("key1" -> "value1", "key2" -> "value2"),
+ Row("x", 1),
+ null)
+ )
+ )
+ }
+ }
+
+ test("Write merge schema: binary and boolean types") {
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b STRING)")
+ sql("INSERT INTO t VALUES (1, '1'), (2, '2')")
+
+ // new columns with binary and boolean types
+ sql(
+ "INSERT INTO t BY NAME SELECT 3 AS a, '3' AS b, " +
+ "cast('binary_data' as binary) AS binary_col, " +
+ "true AS boolean_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(1, "1", null, null),
+ Row(2, "2", null, null),
+ Row(3, "3", "binary_data".getBytes("UTF-8"), true)
+ )
+ )
+
+ // missing columns and new columns with binary and boolean types
+ sql(
+ "INSERT INTO t BY NAME SELECT '4' AS d, '4' AS b, " +
+ "cast('more_data' as binary) AS binary_col, " +
+ "false AS boolean_col")
+ checkAnswer(
+ sql("SELECT * FROM t ORDER BY a"),
+ Seq(
+ Row(null, "4", "more_data".getBytes("UTF-8"), false, "4"),
+ Row(1, "1", null, null, null),
+ Row(2, "2", null, null, null),
+ Row(3, "3", "binary_data".getBytes("UTF-8"), true, null)
+ )
+ )
+ }
+ }
+
+}