This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 08a84e4 [BEAM-6674] Add schema support to JdbcIO read (#8725) 08a84e4 is described below commit 08a84e4717ad184ef3526604851744ea9addd988 Author: Charith Ellawala <chari...@users.noreply.github.com> AuthorDate: Wed Jun 19 07:12:31 2019 +0100 [BEAM-6674] Add schema support to JdbcIO read (#8725) * Adds a readRows method to JdbcIO. * Add exclusion for SpotBugs false positive * Add logical types to maintain 1:1 mapping between JDBC and Beam types * Attach schema if type is in registry --- .../src/main/resources/beam/spotbugs-filter.xml | 13 +- .../sdk/io/jdbc/BeamSchemaInferenceException.java | 25 ++ .../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java | 142 +++++++++ .../org/apache/beam/sdk/io/jdbc/LogicalTypes.java | 241 +++++++++++++++ .../org/apache/beam/sdk/io/jdbc/SchemaUtil.java | 340 +++++++++++++++++++++ .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java | 68 +++++ .../org/apache/beam/sdk/io/jdbc/RowWithSchema.java | 44 +++ .../apache/beam/sdk/io/jdbc/SchemaUtilTest.java | 303 ++++++++++++++++++ 8 files changed, 1174 insertions(+), 2 deletions(-) diff --git a/sdks/java/build-tools/src/main/resources/beam/spotbugs-filter.xml b/sdks/java/build-tools/src/main/resources/beam/spotbugs-filter.xml index f46da65..7811398 100644 --- a/sdks/java/build-tools/src/main/resources/beam/spotbugs-filter.xml +++ b/sdks/java/build-tools/src/main/resources/beam/spotbugs-filter.xml @@ -417,8 +417,17 @@ This is a false positive. Spotbugs does not recognize the use of try-with-resources, so it thinks that the connection is not correctly closed. --> - <Class name="org.apache.beam.sdk.io.jdbc.JdbcIO$ReadFn"/> - <Method name="processElement"/> + <Or> + <And> + <Class name="org.apache.beam.sdk.io.jdbc.JdbcIO$ReadFn"/> + <Method name="processElement"/> + </And> + <And> + <Class name="org.apache.beam.sdk.io.jdbc.JdbcIO$ReadRows"/> + <Method name="inferBeamSchema"/> + </And> + </Or> + <Bug pattern="OBL_UNSATISFIED_OBLIGATION"/> </Match> </FindBugsFilter> diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/BeamSchemaInferenceException.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/BeamSchemaInferenceException.java new file mode 100644 index 0000000..683144d --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/BeamSchemaInferenceException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +/** Exception to signal that inferring the Beam schema from the JDBC source failed. 
*/ +public class BeamSchemaInferenceException extends RuntimeException { + public BeamSchemaInferenceException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 6d6b5e0..2574c45 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -33,7 +33,11 @@ import javax.annotation.Nullable; import javax.sql.DataSource; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Filter; @@ -54,6 +58,8 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbcp2.DataSourceConnectionFactory; import org.apache.commons.dbcp2.PoolableConnectionFactory; @@ -188,6 +194,15 @@ public class JdbcIO { .build(); } + /** Read Beam {@link Row}s from a JDBC data source. */ + @Experimental(Experimental.Kind.SCHEMAS) + public static ReadRows readRows() { + return new AutoValue_JdbcIO_ReadRows.Builder() + .setFetchSize(DEFAULT_FETCH_SIZE) + .setOutputParallelization(true) + .build(); + } + /** * Like {@link #read}, but executes multiple instances of the query substituting each element of a * {@link PCollection} as query parameters. @@ -391,6 +406,123 @@ public class JdbcIO { void setParameters(PreparedStatement preparedStatement) throws Exception; } + /** Implementation of {@link #readRows()}. 
*/ + @AutoValue + @Experimental(Experimental.Kind.SCHEMAS) + public abstract static class ReadRows extends PTransform<PBegin, PCollection<Row>> { + @Nullable + abstract SerializableFunction<Void, DataSource> getDataSourceProviderFn(); + + @Nullable + abstract ValueProvider<String> getQuery(); + + @Nullable + abstract StatementPreparator getStatementPreparator(); + + abstract int getFetchSize(); + + abstract boolean getOutputParallelization(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setDataSourceProviderFn( + SerializableFunction<Void, DataSource> dataSourceProviderFn); + + abstract Builder setQuery(ValueProvider<String> query); + + abstract Builder setStatementPreparator(StatementPreparator statementPreparator); + + abstract Builder setFetchSize(int fetchSize); + + abstract Builder setOutputParallelization(boolean outputParallelization); + + abstract ReadRows build(); + } + + public ReadRows withDataSourceProviderFn( + SerializableFunction<Void, DataSource> dataSourceProviderFn) { + return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build(); + } + + public ReadRows withQuery(String query) { + checkArgument(query != null, "query can not be null"); + return withQuery(ValueProvider.StaticValueProvider.of(query)); + } + + public ReadRows withQuery(ValueProvider<String> query) { + checkArgument(query != null, "query can not be null"); + return toBuilder().setQuery(query).build(); + } + + public ReadRows withStatementPreparator(StatementPreparator statementPreparator) { + checkArgument(statementPreparator != null, "statementPreparator can not be null"); + return toBuilder().setStatementPreparator(statementPreparator).build(); + } + + /** + * This method is used to set the size of the data that is going to be fetched and loaded in + * memory per every database call. Please refer to: {@link java.sql.Statement#setFetchSize(int)} + * It should ONLY be used if the default value throws memory errors. + */ + public ReadRows withFetchSize(int fetchSize) { + checkArgument(fetchSize > 0, "fetch size must be > 0"); + return toBuilder().setFetchSize(fetchSize).build(); + } + + /** + * Whether to reshuffle the resulting PCollection so results are distributed to all workers. The + * default is to parallelize and should only be changed if this is known to be unnecessary. 
+ */ + public ReadRows withOutputParallelization(boolean outputParallelization) { + return toBuilder().setOutputParallelization(outputParallelization).build(); + } + + @Override + public PCollection<Row> expand(PBegin input) { + checkArgument(getQuery() != null, "withQuery() is required"); + checkArgument( + (getDataSourceProviderFn() != null), + "withDataSourceConfiguration() or withDataSourceProviderFn() is required"); + + Schema schema = inferBeamSchema(); + PCollection<Row> rows = + input.apply( + JdbcIO.<Row>read() + .withDataSourceProviderFn(getDataSourceProviderFn()) + .withQuery(getQuery()) + .withCoder(RowCoder.of(schema)) + .withRowMapper(SchemaUtil.BeamRowMapper.of(schema)) + .withFetchSize(getFetchSize()) + .withOutputParallelization(getOutputParallelization()) + .withStatementPreparator(getStatementPreparator())); + rows.setRowSchema(schema); + return rows; + } + + private Schema inferBeamSchema() { + DataSource ds = getDataSourceProviderFn().apply(null); + try (Connection conn = ds.getConnection(); + PreparedStatement statement = + conn.prepareStatement( + getQuery().get(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) { + return SchemaUtil.toBeamSchema(statement.getMetaData()); + } catch (SQLException e) { + throw new BeamSchemaInferenceException("Failed to infer Beam schema", e); + } + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + super.populateDisplayData(builder); + builder.add(DisplayData.item("query", getQuery())); + if (getDataSourceProviderFn() instanceof HasDisplayData) { + ((HasDisplayData) getDataSourceProviderFn()).populateDisplayData(builder); + } + } + } + /** Implementation of {@link #read}. */ @AutoValue public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> { @@ -671,6 +803,16 @@ public class JdbcIO { output = output.apply(new Reparallelize<>()); } + try { + TypeDescriptor<OutputT> typeDesc = getCoder().getEncodedTypeDescriptor(); + SchemaRegistry registry = input.getPipeline().getSchemaRegistry(); + Schema schema = registry.getSchema(typeDesc); + output.setSchema( + schema, registry.getToRowFunction(typeDesc), registry.getFromRowFunction(typeDesc)); + } catch (NoSuchSchemaException e) { + // ignore + } + return output; } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/LogicalTypes.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/LogicalTypes.java new file mode 100644 index 0000000..e024911 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/LogicalTypes.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc; + +import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; + +import java.math.BigDecimal; +import java.sql.JDBCType; +import java.time.Instant; +import java.util.Arrays; +import java.util.Objects; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.StringUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting; + +/** Beam {@link org.apache.beam.sdk.schemas.Schema.LogicalType} implementations of JDBC types. */ +public class LogicalTypes { + public static final Schema.FieldType JDBC_BIT_TYPE = + Schema.FieldType.logicalType( + new org.apache.beam.sdk.schemas.LogicalTypes.PassThroughLogicalType<Boolean>( + JDBCType.BIT.getName(), "", Schema.FieldType.BOOLEAN) {}); + + public static final Schema.FieldType JDBC_DATE_TYPE = + Schema.FieldType.logicalType( + new org.apache.beam.sdk.schemas.LogicalTypes.PassThroughLogicalType<Instant>( + JDBCType.DATE.getName(), "", Schema.FieldType.DATETIME) {}); + + public static final Schema.FieldType JDBC_FLOAT_TYPE = + Schema.FieldType.logicalType( + new org.apache.beam.sdk.schemas.LogicalTypes.PassThroughLogicalType<Double>( + JDBCType.FLOAT.getName(), "", Schema.FieldType.DOUBLE) {}); + + public static final Schema.FieldType JDBC_TIME_TYPE = + Schema.FieldType.logicalType( + new org.apache.beam.sdk.schemas.LogicalTypes.PassThroughLogicalType<Instant>( + JDBCType.TIME.getName(), "", Schema.FieldType.DATETIME) {}); + + public static final Schema.FieldType JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE = + Schema.FieldType.logicalType( + new org.apache.beam.sdk.schemas.LogicalTypes.PassThroughLogicalType<Instant>( + JDBCType.TIMESTAMP_WITH_TIMEZONE.getName(), "", Schema.FieldType.DATETIME) {}); + + @VisibleForTesting + static Schema.FieldType fixedLengthString(JDBCType jdbcType, int length) { + return Schema.FieldType.logicalType(FixedLengthString.of(jdbcType.getName(), length)); + } + + @VisibleForTesting + static Schema.FieldType fixedLengthBytes(JDBCType jdbcType, int length) { + return Schema.FieldType.logicalType(FixedLengthBytes.of(jdbcType.getName(), length)); + } + + @VisibleForTesting + static Schema.FieldType variableLengthString(JDBCType jdbcType, int length) { + return Schema.FieldType.logicalType(VariableLengthString.of(jdbcType.getName(), length)); + } + + @VisibleForTesting + static Schema.FieldType variableLengthBytes(JDBCType jdbcType, int length) { + return Schema.FieldType.logicalType(VariableLengthBytes.of(jdbcType.getName(), length)); + } + + @VisibleForTesting + static Schema.FieldType numeric(int precision, int scale) { + return Schema.FieldType.logicalType( + FixedPrecisionNumeric.of(JDBCType.NUMERIC.getName(), precision, scale)); + } + + /** Base class for JDBC logical types. 
*/ + public abstract static class JdbcLogicalType<T> implements Schema.LogicalType<T, T> { + protected final String identifier; + protected final Schema.FieldType baseType; + protected final String argument; + + protected JdbcLogicalType(String identifier, Schema.FieldType baseType, String argument) { + this.identifier = identifier; + this.baseType = baseType; + this.argument = argument; + } + + @Override + public String getIdentifier() { + return identifier; + } + + @Override + public String getArgument() { + return argument; + } + + @Override + public Schema.FieldType getBaseType() { + return baseType; + } + + @Override + public T toBaseType(T input) { + return input; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof JdbcLogicalType)) { + return false; + } + JdbcLogicalType<?> that = (JdbcLogicalType<?>) o; + return Objects.equals(identifier, that.identifier) + && Objects.equals(baseType, that.baseType) + && Objects.equals(argument, that.argument); + } + + @Override + public int hashCode() { + return Objects.hash(identifier, baseType, argument); + } + } + + /** Fixed length string types such as CHAR. */ + public static final class FixedLengthString extends JdbcLogicalType<String> { + private final int length; + + public static FixedLengthString of(String identifier, int length) { + return new FixedLengthString(identifier, length); + } + + private FixedLengthString(String identifier, int length) { + super(identifier, Schema.FieldType.STRING, String.valueOf(length)); + this.length = length; + } + + @Override + public String toInputType(String base) { + checkArgument(base == null || base.length() <= length); + return StringUtils.rightPad(base, length); + } + } + + /** Fixed length byte types such as BINARY. */ + public static final class FixedLengthBytes extends JdbcLogicalType<byte[]> { + private final int length; + + public static FixedLengthBytes of(String identifier, int length) { + return new FixedLengthBytes(identifier, length); + } + + private FixedLengthBytes(String identifier, int length) { + super(identifier, Schema.FieldType.BYTES, String.valueOf(length)); + this.length = length; + } + + @Override + public byte[] toInputType(byte[] base) { + checkArgument(base == null || base.length <= length); + if (base == null || base.length == length) { + return base; + } else { + return Arrays.copyOf(base, length); + } + } + } + + /** Variable length string types such as VARCHAR and LONGVARCHAR. */ + public static final class VariableLengthString extends JdbcLogicalType<String> { + private final int maxLength; + + public static VariableLengthString of(String identifier, int maxLength) { + return new VariableLengthString(identifier, maxLength); + } + + private VariableLengthString(String identifier, int maxLength) { + super(identifier, Schema.FieldType.STRING, String.valueOf(maxLength)); + this.maxLength = maxLength; + } + + @Override + public String toInputType(String base) { + checkArgument(base == null || base.length() <= maxLength); + return base; + } + } + + /** Variable length bytes types such as VARBINARY and LONGVARBINARY. 
*/ + public static final class VariableLengthBytes extends JdbcLogicalType<byte[]> { + private final int maxLength; + + public static VariableLengthBytes of(String identifier, int maxLength) { + return new VariableLengthBytes(identifier, maxLength); + } + + private VariableLengthBytes(String identifier, int maxLength) { + super(identifier, Schema.FieldType.BYTES, String.valueOf(maxLength)); + this.maxLength = maxLength; + } + + @Override + public byte[] toInputType(byte[] base) { + checkArgument(base == null || base.length <= maxLength); + return base; + } + } + + /** Fixed precision numeric types such as NUMERIC. */ + public static final class FixedPrecisionNumeric extends JdbcLogicalType<BigDecimal> { + private final int precision; + private final int scale; + + public static FixedPrecisionNumeric of(String identifier, int precision, int scale) { + return new FixedPrecisionNumeric(identifier, precision, scale); + } + + private FixedPrecisionNumeric(String identifier, int precision, int scale) { + super(identifier, Schema.FieldType.DECIMAL, precision + ":" + scale); + this.precision = precision; + this.scale = scale; + } + + @Override + public BigDecimal toInputType(BigDecimal base) { + checkArgument(base == null || (base.precision() == precision && base.scale() == scale)); + return base; + } + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java new file mode 100644 index 0000000..263c2de --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc; + +import static java.sql.JDBCType.BINARY; +import static java.sql.JDBCType.CHAR; +import static java.sql.JDBCType.LONGNVARCHAR; +import static java.sql.JDBCType.LONGVARBINARY; +import static java.sql.JDBCType.LONGVARCHAR; +import static java.sql.JDBCType.NCHAR; +import static java.sql.JDBCType.NUMERIC; +import static java.sql.JDBCType.NVARCHAR; +import static java.sql.JDBCType.VARBINARY; +import static java.sql.JDBCType.VARCHAR; +import static java.sql.JDBCType.valueOf; + +import java.io.Serializable; +import java.sql.Array; +import java.sql.Date; +import java.sql.JDBCType; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.EnumMap; +import java.util.List; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; +import org.joda.time.DateTime; +import org.joda.time.chrono.ISOChronology; + +/** Provides utility functions for working with Beam {@link Schema} types. */ +class SchemaUtil { + /** + * Interface implemented by functions that extract values of different types from a JDBC + * ResultSet. + */ + @FunctionalInterface + interface ResultSetFieldExtractor extends Serializable { + Object extract(ResultSet rs, Integer index) throws SQLException; + } + + // ResultSetExtractors for primitive schema types (excluding arrays, structs and logical types). + private static final EnumMap<Schema.TypeName, ResultSetFieldExtractor> + RESULTSET_FIELD_EXTRACTORS = + new EnumMap<>( + ImmutableMap.<Schema.TypeName, ResultSetFieldExtractor>builder() + .put(Schema.TypeName.BOOLEAN, ResultSet::getBoolean) + .put(Schema.TypeName.BYTE, ResultSet::getByte) + .put(Schema.TypeName.BYTES, ResultSet::getBytes) + .put(Schema.TypeName.DATETIME, ResultSet::getTimestamp) + .put(Schema.TypeName.DECIMAL, ResultSet::getBigDecimal) + .put(Schema.TypeName.DOUBLE, ResultSet::getDouble) + .put(Schema.TypeName.FLOAT, ResultSet::getFloat) + .put(Schema.TypeName.INT16, ResultSet::getShort) + .put(Schema.TypeName.INT32, ResultSet::getInt) + .put(Schema.TypeName.INT64, ResultSet::getLong) + .put(Schema.TypeName.STRING, ResultSet::getString) + .build()); + + private static final ResultSetFieldExtractor DATE_EXTRACTOR = createDateExtractor(); + private static final ResultSetFieldExtractor TIME_EXTRACTOR = createTimeExtractor(); + private static final ResultSetFieldExtractor TIMESTAMP_EXTRACTOR = createTimestampExtractor(); + + /** + * Interface implemented by functions that create Beam {@link + * org.apache.beam.sdk.schemas.Schema.Field} corresponding to JDBC field metadata. 
+ */ + @FunctionalInterface + interface BeamFieldConverter extends Serializable { + Schema.Field create(int index, ResultSetMetaData md) throws SQLException; + } + + private static BeamFieldConverter jdbcTypeToBeamFieldConverter(JDBCType jdbcType) { + switch (jdbcType) { + case ARRAY: + return beamArrayField(); + case BIGINT: + return beamFieldOfType(Schema.FieldType.INT64); + case BINARY: + return beamLogicalField(BINARY.getName(), LogicalTypes.FixedLengthBytes::of); + case BIT: + return beamFieldOfType(LogicalTypes.JDBC_BIT_TYPE); + case BOOLEAN: + return beamFieldOfType(Schema.FieldType.BOOLEAN); + case CHAR: + return beamLogicalField(CHAR.getName(), LogicalTypes.FixedLengthString::of); + case DATE: + return beamFieldOfType(LogicalTypes.JDBC_DATE_TYPE); + case DECIMAL: + return beamFieldOfType(Schema.FieldType.DECIMAL); + case DOUBLE: + return beamFieldOfType(Schema.FieldType.DOUBLE); + case FLOAT: + return beamFieldOfType(LogicalTypes.JDBC_FLOAT_TYPE); + case INTEGER: + return beamFieldOfType(Schema.FieldType.INT32); + case LONGNVARCHAR: + return beamLogicalField(LONGNVARCHAR.getName(), LogicalTypes.VariableLengthString::of); + case LONGVARBINARY: + return beamLogicalField(LONGVARBINARY.getName(), LogicalTypes.VariableLengthBytes::of); + case LONGVARCHAR: + return beamLogicalField(LONGVARCHAR.getName(), LogicalTypes.VariableLengthString::of); + case NCHAR: + return beamLogicalField(NCHAR.getName(), LogicalTypes.FixedLengthString::of); + case NUMERIC: + return beamLogicalNumericField(NUMERIC.getName()); + case NVARCHAR: + return beamLogicalField(NVARCHAR.getName(), LogicalTypes.VariableLengthString::of); + case REAL: + return beamFieldOfType(Schema.FieldType.FLOAT); + case SMALLINT: + return beamFieldOfType(Schema.FieldType.INT16); + case TIME: + return beamFieldOfType(LogicalTypes.JDBC_TIME_TYPE); + case TIMESTAMP: + return beamFieldOfType(Schema.FieldType.DATETIME); + case TIMESTAMP_WITH_TIMEZONE: + return beamFieldOfType(LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE); + case TINYINT: + return beamFieldOfType(Schema.FieldType.BYTE); + case VARBINARY: + return beamLogicalField(VARBINARY.getName(), LogicalTypes.VariableLengthBytes::of); + case VARCHAR: + return beamLogicalField(VARCHAR.getName(), LogicalTypes.VariableLengthString::of); + default: + throw new UnsupportedOperationException( + "Converting " + jdbcType + " to Beam schema type is not supported"); + } + } + + /** Infers the Beam {@link Schema} from {@link ResultSetMetaData}. */ + static Schema toBeamSchema(ResultSetMetaData md) throws SQLException { + Schema.Builder schemaBuilder = Schema.builder(); + + for (int i = 1; i <= md.getColumnCount(); i++) { + JDBCType jdbcType = valueOf(md.getColumnType(i)); + BeamFieldConverter fieldConverter = jdbcTypeToBeamFieldConverter(jdbcType); + schemaBuilder.addField(fieldConverter.create(i, md)); + } + + return schemaBuilder.build(); + } + + /** Converts a primitive JDBC field to corresponding Beam schema field. */ + private static BeamFieldConverter beamFieldOfType(Schema.FieldType fieldType) { + return (index, md) -> { + String label = md.getColumnLabel(index); + return Schema.Field.of(label, fieldType) + .withNullable(md.isNullable(index) == ResultSetMetaData.columnNullable); + }; + } + + /** Converts logical types with arguments such as VARCHAR(25). 
*/ + private static <InputT, BaseT> BeamFieldConverter beamLogicalField( + String identifier, + BiFunction<String, Integer, Schema.LogicalType<InputT, BaseT>> constructor) { + return (index, md) -> { + int size = md.getPrecision(index); + Schema.FieldType fieldType = + Schema.FieldType.logicalType(constructor.apply(identifier, size)); + return beamFieldOfType(fieldType).create(index, md); + }; + } + + /** Converts numeric fields with specified precision and scale. */ + private static BeamFieldConverter beamLogicalNumericField(String identifier) { + return (index, md) -> { + int precision = md.getPrecision(index); + int scale = md.getScale(index); + Schema.FieldType fieldType = + Schema.FieldType.logicalType( + LogicalTypes.FixedPrecisionNumeric.of(identifier, precision, scale)); + return beamFieldOfType(fieldType).create(index, md); + }; + } + + /** Converts array fields. */ + private static BeamFieldConverter beamArrayField() { + return (index, md) -> { + JDBCType elementJdbcType = valueOf(md.getColumnTypeName(index)); + BeamFieldConverter elementFieldConverter = jdbcTypeToBeamFieldConverter(elementJdbcType); + + String label = md.getColumnLabel(index); + Schema.FieldType elementBeamType = elementFieldConverter.create(index, md).getType(); + return Schema.Field.of(label, Schema.FieldType.array(elementBeamType)) + .withNullable(md.isNullable(index) == ResultSetMetaData.columnNullable); + }; + } + + /** Creates a {@link ResultSetFieldExtractor} for the given type. */ + private static ResultSetFieldExtractor createFieldExtractor(Schema.FieldType fieldType) { + Schema.TypeName typeName = fieldType.getTypeName(); + switch (typeName) { + case ARRAY: + Schema.FieldType elementType = fieldType.getCollectionElementType(); + ResultSetFieldExtractor elementExtractor = createFieldExtractor(elementType); + return createArrayExtractor(elementExtractor); + case DATETIME: + return TIMESTAMP_EXTRACTOR; + case LOGICAL_TYPE: + return createLogicalTypeExtractor(fieldType.getLogicalType()); + default: + if (!RESULTSET_FIELD_EXTRACTORS.containsKey(typeName)) { + throw new UnsupportedOperationException( + "BeamRowMapper does not have support for fields of type " + fieldType.toString()); + } + return RESULTSET_FIELD_EXTRACTORS.get(typeName); + } + } + + /** Creates a {@link ResultSetFieldExtractor} for array types. */ + private static ResultSetFieldExtractor createArrayExtractor( + ResultSetFieldExtractor elementExtractor) { + return (rs, index) -> { + Array arrayVal = rs.getArray(index); + if (arrayVal == null) { + return null; + } + + List<Object> arrayElements = new ArrayList<>(); + ResultSet arrayRs = arrayVal.getResultSet(); + while (arrayRs.next()) { + arrayElements.add(elementExtractor.extract(arrayRs, 1)); + } + return arrayElements; + }; + } + + /** Creates a {@link ResultSetFieldExtractor} for logical types. */ + private static <InputT, BaseT> ResultSetFieldExtractor createLogicalTypeExtractor( + final Schema.LogicalType<InputT, BaseT> fieldType) { + String logicalTypeName = fieldType.getIdentifier(); + JDBCType underlyingType = JDBCType.valueOf(logicalTypeName); + switch (underlyingType) { + case DATE: + return DATE_EXTRACTOR; + case TIME: + return TIME_EXTRACTOR; + case TIMESTAMP_WITH_TIMEZONE: + return TIMESTAMP_EXTRACTOR; + default: + ResultSetFieldExtractor extractor = createFieldExtractor(fieldType.getBaseType()); + return (rs, index) -> fieldType.toInputType((BaseT) extractor.extract(rs, index)); + } + } + + /** Convert SQL date type to Beam DateTime. 
*/ + private static ResultSetFieldExtractor createDateExtractor() { + return (rs, i) -> { + Date date = rs.getDate(i); + if (date == null) { + return null; + } + ZonedDateTime zdt = ZonedDateTime.of(date.toLocalDate(), LocalTime.MIDNIGHT, ZoneOffset.UTC); + return new DateTime(zdt.toInstant().toEpochMilli(), ISOChronology.getInstanceUTC()); + }; + } + + /** Convert SQL time type to Beam DateTime. */ + private static ResultSetFieldExtractor createTimeExtractor() { + return (rs, i) -> { + Time time = rs.getTime(i); + if (time == null) { + return null; + } + ZonedDateTime zdt = + ZonedDateTime.of(LocalDate.ofEpochDay(0), time.toLocalTime(), ZoneOffset.systemDefault()); + return new DateTime(zdt.toInstant().toEpochMilli(), ISOChronology.getInstanceUTC()); + }; + } + + /** Convert SQL timestamp type to Beam DateTime. */ + private static ResultSetFieldExtractor createTimestampExtractor() { + return (rs, i) -> { + Timestamp ts = rs.getTimestamp(i); + if (ts == null) { + return null; + } + return new DateTime(ts.toInstant().toEpochMilli(), ISOChronology.getInstanceUTC()); + }; + } + + /** + * A {@link org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper} implementation that converts JDBC + * results into Beam {@link Row} objects. + */ + static final class BeamRowMapper implements JdbcIO.RowMapper<Row> { + private final Schema schema; + private final List<ResultSetFieldExtractor> fieldExtractors; + + public static BeamRowMapper of(Schema schema) { + List<ResultSetFieldExtractor> fieldExtractors = + IntStream.range(0, schema.getFieldCount()) + .mapToObj(i -> createFieldExtractor(schema.getField(i).getType())) + .collect(Collectors.toList()); + + return new BeamRowMapper(schema, fieldExtractors); + } + + private BeamRowMapper(Schema schema, List<ResultSetFieldExtractor> fieldExtractors) { + this.schema = schema; + this.fieldExtractors = fieldExtractors; + } + + @Override + public Row mapRow(ResultSet rs) throws Exception { + Row.Builder rowBuilder = Row.withSchema(schema); + for (int i = 0; i < schema.getFieldCount(); i++) { + rowBuilder.addValue(fieldExtractors.get(i).extract(rs, i + 1)); + } + return rowBuilder.build(); + } + } +} diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 82df39b..ecd26a1 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.io.jdbc; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.PrintWriter; @@ -24,6 +25,7 @@ import java.io.Serializable; import java.io.StringWriter; import java.net.InetAddress; import java.sql.Connection; +import java.sql.JDBCType; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -38,14 +40,19 @@ import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.io.common.DatabaseTestHelper; import org.apache.beam.sdk.io.common.NetworkTestHelper; import org.apache.beam.sdk.io.common.TestRow; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.transforms.Select; import org.apache.beam.sdk.testing.ExpectedLogs; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; +import 
org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Wait; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList; import org.apache.commons.dbcp2.PoolingDataSource; import org.apache.derby.drda.NetworkServerControl; import org.apache.derby.jdbc.ClientDataSource; @@ -273,6 +280,67 @@ public class JdbcIOTest implements Serializable { } @Test + public void testReadRows() { + SerializableFunction<Void, DataSource> dataSourceProvider = ignored -> dataSource; + PCollection<Row> rows = + pipeline.apply( + JdbcIO.readRows() + .withDataSourceProviderFn(dataSourceProvider) + .withQuery(String.format("select name,id from %s where name = ?", readTableName)) + .withStatementPreparator( + preparedStatement -> + preparedStatement.setString(1, TestRow.getNameForSeed(1)))); + + Schema expectedSchema = + Schema.of( + Schema.Field.of("NAME", LogicalTypes.variableLengthString(JDBCType.VARCHAR, 500)) + .withNullable(true), + Schema.Field.of("ID", Schema.FieldType.INT32).withNullable(true)); + + assertEquals(expectedSchema, rows.getSchema()); + + PCollection<Row> output = rows.apply(Select.fieldNames("NAME", "ID")); + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of(Row.withSchema(expectedSchema).addValues("Testval1", 1).build())); + + pipeline.run(); + } + + @Test + public void testReadWithSchema() { + SerializableFunction<Void, DataSource> dataSourceProvider = ignored -> dataSource; + JdbcIO.RowMapper<RowWithSchema> rowMapper = + rs -> new RowWithSchema(rs.getString("NAME"), rs.getInt("ID")); + pipeline.getSchemaRegistry().registerJavaBean(RowWithSchema.class); + + PCollection<RowWithSchema> rows = + pipeline.apply( + JdbcIO.<RowWithSchema>read() + .withDataSourceProviderFn(dataSourceProvider) + .withQuery(String.format("select name,id from %s where name = ?", readTableName)) + .withRowMapper(rowMapper) + .withCoder(SerializableCoder.of(RowWithSchema.class)) + .withStatementPreparator( + preparedStatement -> + preparedStatement.setString(1, TestRow.getNameForSeed(1)))); + + Schema expectedSchema = + Schema.of( + Schema.Field.of("name", Schema.FieldType.STRING), + Schema.Field.of("id", Schema.FieldType.INT32)); + + assertEquals(expectedSchema, rows.getSchema()); + + PCollection<Row> output = rows.apply(Select.fieldNames("name", "id")); + PAssert.that(output) + .containsInAnyOrder( + ImmutableList.of(Row.withSchema(expectedSchema).addValues("Testval1", 1).build())); + + pipeline.run(); + } + + @Test public void testWrite() throws Exception { final long rowsToAdd = 1000L; diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/RowWithSchema.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/RowWithSchema.java new file mode 100644 index 0000000..a175216 --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/RowWithSchema.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import java.io.Serializable; +import org.apache.beam.sdk.schemas.JavaBeanSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.annotations.SchemaCreate; + +/** Test row. */ +@DefaultSchema(JavaBeanSchema.class) +public class RowWithSchema implements Serializable { + private final String name; + private final int id; + + @SchemaCreate + public RowWithSchema(String name, int id) { + this.name = name; + this.id = id; + } + + public String getName() { + return name; + } + + public int getId() { + return id; + } +} diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java new file mode 100644 index 0000000..63b3d88 --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.math.BigDecimal; +import java.nio.charset.Charset; +import java.sql.Array; +import java.sql.Date; +import java.sql.JDBCType; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList; +import org.joda.time.DateTime; +import org.joda.time.LocalDate; +import org.joda.time.chrono.ISOChronology; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test SchemaUtils. 
*/ +@RunWith(JUnit4.class) +public class SchemaUtilTest { + @Test + public void testToBeamSchema() throws SQLException { + ResultSetMetaData mockResultSetMetaData = mock(ResultSetMetaData.class); + + ImmutableList<JdbcFieldInfo> fieldInfo = + ImmutableList.of( + JdbcFieldInfo.of("int_array_col", Types.ARRAY, JDBCType.INTEGER.getName(), false), + JdbcFieldInfo.of("bigint_col", Types.BIGINT), + JdbcFieldInfo.of("binary_col", Types.BINARY, 255), + JdbcFieldInfo.of("bit_col", Types.BIT), + JdbcFieldInfo.of("boolean_col", Types.BOOLEAN), + JdbcFieldInfo.of("char_col", Types.CHAR, 255), + JdbcFieldInfo.of("date_col", Types.DATE), + JdbcFieldInfo.of("decimal_col", Types.DECIMAL), + JdbcFieldInfo.of("double_col", Types.DOUBLE), + JdbcFieldInfo.of("float_col", Types.FLOAT), + JdbcFieldInfo.of("integer_col", Types.INTEGER), + JdbcFieldInfo.of("longnvarchar_col", Types.LONGNVARCHAR, 1024), + JdbcFieldInfo.of("longvarchar_col", Types.LONGVARCHAR, 1024), + JdbcFieldInfo.of("longvarbinary_col", Types.LONGVARBINARY, 1024), + JdbcFieldInfo.of("nchar_col", Types.NCHAR, 255), + JdbcFieldInfo.of("numeric_col", Types.NUMERIC, 12, 4), + JdbcFieldInfo.of("nvarchar_col", Types.NVARCHAR, 255), + JdbcFieldInfo.of("real_col", Types.REAL), + JdbcFieldInfo.of("smallint_col", Types.SMALLINT), + JdbcFieldInfo.of("time_col", Types.TIME), + JdbcFieldInfo.of("timestamp_col", Types.TIMESTAMP), + JdbcFieldInfo.of("timestamptz_col", Types.TIMESTAMP_WITH_TIMEZONE), + JdbcFieldInfo.of("tinyint_col", Types.TINYINT), + JdbcFieldInfo.of("varbinary_col", Types.VARBINARY, 255), + JdbcFieldInfo.of("varchar_col", Types.VARCHAR, 255)); + + when(mockResultSetMetaData.getColumnCount()).thenReturn(fieldInfo.size()); + for (int i = 0; i < fieldInfo.size(); i++) { + JdbcFieldInfo f = fieldInfo.get(i); + when(mockResultSetMetaData.getColumnLabel(eq(i + 1))).thenReturn(f.columnLabel); + when(mockResultSetMetaData.getColumnType(eq(i + 1))).thenReturn(f.columnType); + when(mockResultSetMetaData.getColumnTypeName(eq(i + 1))).thenReturn(f.columnTypeName); + when(mockResultSetMetaData.getPrecision(eq(i + 1))).thenReturn(f.precision); + when(mockResultSetMetaData.getScale(eq(i + 1))).thenReturn(f.scale); + when(mockResultSetMetaData.isNullable(eq(i + 1))) + .thenReturn( + f.nullable ? 
ResultSetMetaData.columnNullable : ResultSetMetaData.columnNoNulls); + } + + Schema wantBeamSchema = + Schema.builder() + .addArrayField("int_array_col", Schema.FieldType.INT32) + .addField("bigint_col", Schema.FieldType.INT64) + .addField("binary_col", LogicalTypes.fixedLengthBytes(JDBCType.BINARY, 255)) + .addField("bit_col", LogicalTypes.JDBC_BIT_TYPE) + .addField("boolean_col", Schema.FieldType.BOOLEAN) + .addField("char_col", LogicalTypes.fixedLengthString(JDBCType.CHAR, 255)) + .addField("date_col", LogicalTypes.JDBC_DATE_TYPE) + .addField("decimal_col", Schema.FieldType.DECIMAL) + .addField("double_col", Schema.FieldType.DOUBLE) + .addField("float_col", LogicalTypes.JDBC_FLOAT_TYPE) + .addField("integer_col", Schema.FieldType.INT32) + .addField( + "longnvarchar_col", LogicalTypes.variableLengthString(JDBCType.LONGNVARCHAR, 1024)) + .addField( + "longvarchar_col", LogicalTypes.variableLengthString(JDBCType.LONGVARCHAR, 1024)) + .addField( + "longvarbinary_col", LogicalTypes.variableLengthBytes(JDBCType.LONGVARBINARY, 1024)) + .addField("nchar_col", LogicalTypes.fixedLengthString(JDBCType.NCHAR, 255)) + .addField("numeric_col", LogicalTypes.numeric(12, 4)) + .addField("nvarchar_col", LogicalTypes.variableLengthString(JDBCType.NVARCHAR, 255)) + .addField("real_col", Schema.FieldType.FLOAT) + .addField("smallint_col", Schema.FieldType.INT16) + .addField("time_col", LogicalTypes.JDBC_TIME_TYPE) + .addField("timestamp_col", Schema.FieldType.DATETIME) + .addField("timestamptz_col", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE) + .addField("tinyint_col", Schema.FieldType.BYTE) + .addField("varbinary_col", LogicalTypes.variableLengthBytes(JDBCType.VARBINARY, 255)) + .addField("varchar_col", LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255)) + .build(); + + Schema haveBeamSchema = SchemaUtil.toBeamSchema(mockResultSetMetaData); + assertEquals(wantBeamSchema, haveBeamSchema); + } + + @Test + public void testBeamRowMapper_array() throws Exception { + ResultSet mockArrayElementsResultSet = mock(ResultSet.class); + when(mockArrayElementsResultSet.next()).thenReturn(true, true, true, false); + when(mockArrayElementsResultSet.getInt(eq(1))).thenReturn(10, 20, 30); + + Array mockArray = mock(Array.class); + when(mockArray.getResultSet()).thenReturn(mockArrayElementsResultSet); + + ResultSet mockResultSet = mock(ResultSet.class); + when(mockResultSet.getArray(eq(1))).thenReturn(mockArray); + + Schema wantSchema = + Schema.builder().addField("array", Schema.FieldType.array(Schema.FieldType.INT32)).build(); + Row wantRow = + Row.withSchema(wantSchema).addValues((Object) ImmutableList.of(10, 20, 30)).build(); + + SchemaUtil.BeamRowMapper beamRowMapper = SchemaUtil.BeamRowMapper.of(wantSchema); + Row haveRow = beamRowMapper.mapRow(mockResultSet); + + assertEquals(wantRow, haveRow); + } + + @Test + public void testBeamRowMapper_primitiveTypes() throws Exception { + ResultSet mockResultSet = mock(ResultSet.class); + when(mockResultSet.getLong(eq(1))).thenReturn(42L); + when(mockResultSet.getBytes(eq(2))).thenReturn("binary".getBytes(Charset.forName("UTF-8"))); + when(mockResultSet.getBoolean(eq(3))).thenReturn(true); + when(mockResultSet.getBoolean(eq(4))).thenReturn(false); + when(mockResultSet.getString(eq(5))).thenReturn("char"); + when(mockResultSet.getBigDecimal(eq(6))).thenReturn(BigDecimal.valueOf(25L)); + when(mockResultSet.getDouble(eq(7))).thenReturn(20.5D); + when(mockResultSet.getFloat(eq(8))).thenReturn(15.5F); + when(mockResultSet.getInt(eq(9))).thenReturn(10); + 
when(mockResultSet.getString(eq(10))).thenReturn("longvarchar"); + when(mockResultSet.getBytes(eq(11))) + .thenReturn("longvarbinary".getBytes(Charset.forName("UTF-8"))); + when(mockResultSet.getBigDecimal(eq(12))).thenReturn(BigDecimal.valueOf(1000L)); + when(mockResultSet.getFloat(eq(13))).thenReturn(32F); + when(mockResultSet.getShort(eq(14))).thenReturn((short) 8); + when(mockResultSet.getShort(eq(15))).thenReturn((short) 4); + when(mockResultSet.getBytes(eq(16))).thenReturn("varbinary".getBytes(Charset.forName("UTF-8"))); + when(mockResultSet.getString(eq(17))).thenReturn("varchar"); + + Schema wantSchema = + Schema.builder() + .addField("bigint_col", Schema.FieldType.INT64) + .addField("binary_col", Schema.FieldType.BYTES) + .addField("bit_col", Schema.FieldType.BOOLEAN) + .addField("boolean_col", Schema.FieldType.BOOLEAN) + .addField("char_col", Schema.FieldType.STRING) + .addField("decimal_col", Schema.FieldType.DECIMAL) + .addField("double_col", Schema.FieldType.DOUBLE) + .addField("float_col", Schema.FieldType.FLOAT) + .addField("integer_col", Schema.FieldType.INT32) + .addField("longvarchar_col", Schema.FieldType.STRING) + .addField("longvarbinary_col", Schema.FieldType.BYTES) + .addField("numeric_col", Schema.FieldType.DECIMAL) + .addField("real_col", Schema.FieldType.FLOAT) + .addField("smallint_col", Schema.FieldType.INT16) + .addField("tinyint_col", Schema.FieldType.INT16) + .addField("varbinary_col", Schema.FieldType.BYTES) + .addField("varchar_col", Schema.FieldType.STRING) + .build(); + Row wantRow = + Row.withSchema(wantSchema) + .addValues( + 42L, + "binary".getBytes(Charset.forName("UTF-8")), + true, + false, + "char", + BigDecimal.valueOf(25L), + 20.5D, + 15.5F, + 10, + "longvarchar", + "longvarbinary".getBytes(Charset.forName("UTF-8")), + BigDecimal.valueOf(1000L), + 32F, + (short) 8, + (short) 4, + "varbinary".getBytes(Charset.forName("UTF-8")), + "varchar") + .build(); + + SchemaUtil.BeamRowMapper beamRowMapper = SchemaUtil.BeamRowMapper.of(wantSchema); + Row haveRow = beamRowMapper.mapRow(mockResultSet); + + assertEquals(wantRow, haveRow); + } + + @Test + public void testBeamRowMapper_datetime() throws Exception { + long epochMilli = 1558719710000L; + + ResultSet mockResultSet = mock(ResultSet.class); + when(mockResultSet.getDate(eq(1))).thenReturn(new Date(epochMilli)); + when(mockResultSet.getTime(eq(2))).thenReturn(new Time(epochMilli)); + when(mockResultSet.getTimestamp(eq(3))).thenReturn(new Timestamp(epochMilli)); + when(mockResultSet.getTimestamp(eq(4))).thenReturn(new Timestamp(epochMilli)); + + Schema wantSchema = + Schema.builder() + .addField("date_col", LogicalTypes.JDBC_DATE_TYPE) + .addField("time_col", LogicalTypes.JDBC_TIME_TYPE) + .addField("timestamptz_col", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE) + .addField("timestamp_col", Schema.FieldType.DATETIME) + .build(); + + DateTime wantDateTime = new DateTime(epochMilli, ISOChronology.getInstanceUTC()); + + Row wantRow = + Row.withSchema(wantSchema) + .addValues( + wantDateTime.withTimeAtStartOfDay(), + wantDateTime.withDate(new LocalDate(0L)), + wantDateTime, + wantDateTime) + .build(); + + SchemaUtil.BeamRowMapper beamRowMapper = SchemaUtil.BeamRowMapper.of(wantSchema); + Row haveRow = beamRowMapper.mapRow(mockResultSet); + + assertEquals(wantRow, haveRow); + } + + //////////////////////////////////////////////////////////////////////////////////////// + private static final class JdbcFieldInfo { + private final String columnLabel; + private final int columnType; + private final String 
columnTypeName; + private final boolean nullable; + private final int precision; + private final int scale; + + private JdbcFieldInfo( + String columnLabel, + int columnType, + String columnTypeName, + boolean nullable, + int precision, + int scale) { + this.columnLabel = columnLabel; + this.columnType = columnType; + this.columnTypeName = columnTypeName; + this.nullable = nullable; + this.precision = precision; + this.scale = scale; + } + + private static JdbcFieldInfo of( + String columnLabel, int columnType, String columnTypeName, boolean nullable) { + return new JdbcFieldInfo(columnLabel, columnType, columnTypeName, nullable, 0, 0); + } + + private static JdbcFieldInfo of(String columnLabel, int columnType, boolean nullable) { + return new JdbcFieldInfo(columnLabel, columnType, null, nullable, 0, 0); + } + + private static JdbcFieldInfo of(String columnLabel, int columnType) { + return new JdbcFieldInfo(columnLabel, columnType, null, false, 0, 0); + } + + private static JdbcFieldInfo of(String columnLabel, int columnType, int precision) { + return new JdbcFieldInfo(columnLabel, columnType, null, false, precision, 0); + } + + private static JdbcFieldInfo of(String columnLabel, int columnType, int precision, int scale) { + return new JdbcFieldInfo(columnLabel, columnType, null, false, precision, scale); + } + } +}
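
For readers skimming the diff, here is a minimal usage sketch of the new JdbcIO.readRows() transform, adapted from the testReadRows case added above. The table name, query, and DataSource wiring are hypothetical stand-ins, not part of the commit:

    import javax.sql.DataSource;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.jdbc.JdbcIO;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.Row;

    public class ReadRowsExample {
      public static PCollection<Row> readUsers(Pipeline pipeline, DataSource dataSource) {
        // readRows() infers the Beam Schema up front from the JDBC
        // ResultSetMetaData (via SchemaUtil.toBeamSchema), so the returned
        // PCollection<Row> already carries a schema at construction time.
        return pipeline.apply(
            JdbcIO.readRows()
                .withDataSourceProviderFn(ignored -> dataSource)
                .withQuery("select name, id from users where name = ?")
                .withStatementPreparator(stmt -> stmt.setString(1, "Testval1")));
      }
    }

Because the schema is attached, downstream schema-aware transforms such as Select.fieldNames("NAME", "ID") work directly on the result, as the new test demonstrates. (The provider function must be serializable in a real pipeline; capturing a local DataSource as above only works when it is.)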
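
Separately, a sketch of the schema-attach behaviour this commit adds to the existing JdbcIO.read() path: when the output type has a schema registered in the pipeline's SchemaRegistry, Read.expand() now looks it up via the coder's encoded TypeDescriptor and attaches it. This fragment is adapted from testReadWithSchema and assumes the same pipeline and dataSource as the previous sketch, with SerializableCoder and Select imported from the Beam SDK; RowWithSchema is the test bean introduced by this commit and the query is a placeholder:

    pipeline.getSchemaRegistry().registerJavaBean(RowWithSchema.class);

    PCollection<RowWithSchema> rows =
        pipeline.apply(
            JdbcIO.<RowWithSchema>read()
                .withDataSourceProviderFn(ignored -> dataSource)
                .withQuery("select name, id from users")
                .withRowMapper(rs -> new RowWithSchema(rs.getString("NAME"), rs.getInt("ID")))
                .withCoder(SerializableCoder.of(RowWithSchema.class)));

    // rows.getSchema() now returns the registered JavaBean schema, so
    // rows.apply(Select.fieldNames("name", "id")) yields Row values without
    // manual conversion. If no schema is registered, the NoSuchSchemaException
    // is swallowed and read() behaves exactly as before.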