gemini-code-assist[bot] commented on code in PR #38212:
URL: https://github.com/apache/beam/pull/38212#discussion_r3092412901
##########
runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java:
##########
@@ -522,7 +522,9 @@ JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupByKeyAndWindow(
Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>
firedStream =
pairDStream.updateStateByKey(
- updateFunc,
+ // Raw cast to AbstractFunction1 suppresses Scala 2.12 (collection.Seq) vs
+ // Scala 2.13 (immutable.Seq) type difference — safe at runtime due to erasure.
+ (scala.runtime.AbstractFunction1) updateFunc,
Review Comment:

Casting `updateFunc` to `scala.runtime.AbstractFunction1` will cause a
`ClassCastException` at runtime because `updateFunc` is a 2-argument function
(`Function2`). To suppress the `Seq` type difference between Scala 2.12 and
2.13 while maintaining correctness, use a raw cast to `scala.Function2` instead.
```suggestion
(scala.Function2) updateFunc,
```
##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java:
##########
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.io;
+
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.sdk.values.WindowedValues.timestampedValueInGlobalWindow;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static scala.collection.JavaConverters.asScalaIterator;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
+import java.util.function.Supplier;
+import javax.annotation.CheckForNull;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.values.WindowedValue;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.Partition;
+import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.classic.Dataset$;
+import org.apache.spark.sql.connector.catalog.SupportsRead;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCapability;
+import org.apache.spark.sql.connector.read.Batch;
+import org.apache.spark.sql.connector.read.InputPartition;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.connector.read.PartitionReaderFactory;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+
+public class BoundedDatasetFactory {
+ private BoundedDatasetFactory() {}
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}.
+ *
+ * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization.
+ * This makes this approach for the time being significantly less performant than creating a
+ * dataset from an RDD.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRows(
+ SparkSession session,
+ BoundedSource<T> source,
+ Supplier<PipelineOptions> options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ BeamTable<T> table = new BeamTable<>(source, params);
+ LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty());
+ // In Spark 4.0+, Dataset$ moved to org.apache.spark.sql.classic; cast session accordingly.
+ return (Dataset<WindowedValue<T>>)
+ Dataset$.MODULE$
+ .ofRows((org.apache.spark.sql.classic.SparkSession) session, logicalPlan)
+ .as(encoder);
+ }
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}.
+ *
+ * <p>This is currently the most efficient approach as it avoids any serialization overhead.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD(
+ SparkSession session,
+ BoundedSource<T> source,
+ Supplier<PipelineOptions> options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params);
+ return session.createDataset(rdd, encoder);
+ }
+
+ /** An {@link RDD} for a bounded Beam source. */
+ private static class BoundedRDD<T> extends RDD<WindowedValue<T>> {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> params) {
+ super(sc, emptyList(), ClassTag.apply(WindowedValue.class));
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public Iterator<WindowedValue<T>> compute(Partition split, TaskContext context) {
+ return new InterruptibleIterator<>(
+ context,
+ asScalaIterator(new SourcePartitionIterator<>((SourcePartition<T>) split, params)));
+ }
+
+ @Override
+ public Partition[] getPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]);
+ }
+ }
+
+ /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */
+ private static class BeamTable<T> implements Table, SupportsRead {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ BeamTable(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ public Encoder<WindowedValue<T>> getEncoder() {
+ return params.encoder;
+ }
+
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) {
+ return () ->
+ new Scan() {
+ @Override
+ public StructType readSchema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Batch toBatch() {
+ return new BeamBatch<>(source, params);
+ }
+ };
+ }
+
+ @Override
+ public String name() {
+ return "BeamSource<" + source.getClass().getName() + ">";
+ }
+
+ @Override
+ public StructType schema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Set<TableCapability> capabilities() {
+ return ImmutableSet.of(TableCapability.BATCH_READ);
+ }
+
+ private static class BeamBatch<T> implements Batch, Serializable {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ private BeamBatch(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]);
+ }
+
+ @Override
+ public PartitionReaderFactory createReaderFactory() {
+ return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), params);
+ }
+ }
+
+ private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
+ final SourcePartitionIterator<T> iterator;
+ final Serializer<WindowedValue<T>> serializer;
+ transient @Nullable InternalRow next;
+
+ BeamPartitionReader(SourcePartition<T> partition, Params<T> params) {
+ iterator = new SourcePartitionIterator<>(partition, params);
+ serializer = ((ExpressionEncoder<WindowedValue<T>>) params.encoder).createSerializer();
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ if (iterator.hasNext()) {
+ next = serializer.apply(iterator.next());
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public InternalRow get() {
+ if (next == null) {
+ throw new IllegalStateException("Next not available");
+ }
+ return next;
+ }
+
+ @Override
+ public void close() throws IOException {
+ next = null;
+ iterator.close();
+ }
+ }
+ }
+
+ /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartition<T> implements Partition, InputPartition {
+ final BoundedSource<T> source;
+ final int index;
+
+ SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) {
+ this.source = source;
+ this.index = idxSupplier.getAsInt();
+ }
+
+ static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, Params<T> params) {
+ try {
+ PipelineOptions options = params.options.get();
+ long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions;
+ List<BoundedSource<T>> split = (List<BoundedSource<T>>) source.split(desiredSize, options);
+ IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement;
+ return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList());
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e);
+ }
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public int hashCode() {
+ return index;
+ }
+ }
+
+ /** A partition iterator on a partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartitionIterator<T> extends AbstractIterator<WindowedValue<T>>
+ implements Closeable {
+ BoundedReader<T> reader;
+ boolean started = false;
+
+ public SourcePartitionIterator(SourcePartition<T> partition, Params<T> params) {
+ try {
+ reader = partition.source.createReader(params.options.get());
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create reader from a BoundedSource.", e);
+ }
+ }
+
+ @Override
+ @SuppressWarnings("nullness") // ok, reader not used any longer
+ public void close() throws IOException {
+ if (reader != null) {
+ endOfData();
+ try {
+ reader.close();
Review Comment:

The call to `endOfData()` here is unnecessary. `endOfData()` is a method in
Guava's `AbstractIterator` that returns a sentinel value to signal the end of
iteration; it is not a state-changing method that needs to be called inside
`close()`. The iterator state is already correctly handled in `computeNext()`
when it returns the result of `endOfData()`.
##########
runners/spark/4/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java:
##########
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming;
+
+import static org.apache.beam.runners.core.metrics.MetricsContainerStepMap.asAttemptedOnlyMetricResults;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.metrics.MetricResults;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.spark.SparkException;
+import org.joda.time.Duration;
+
+public class SparkStructuredStreamingPipelineResult implements PipelineResult {
+
+ private final Future<?> pipelineExecution;
+ private final MetricsAccumulator metrics;
+ private @Nullable final Runnable onTerminalState;
+ private PipelineResult.State state;
+
+ SparkStructuredStreamingPipelineResult(
+ Future<?> pipelineExecution,
+ MetricsAccumulator metrics,
+ @Nullable final Runnable onTerminalState) {
+ this.pipelineExecution = pipelineExecution;
+ this.metrics = metrics;
+ this.onTerminalState = onTerminalState;
+ // pipelineExecution is expected to have started executing eagerly.
+ this.state = State.RUNNING;
+ }
+
+ private static RuntimeException runtimeExceptionFrom(final Throwable e) {
+ return (e instanceof RuntimeException) ? (RuntimeException) e : new RuntimeException(e);
+ }
+
+ /**
+ * Unwrap cause of SparkException or UserCodeException as PipelineExecutionException. Otherwise,
+ * return {@code exception} as RuntimeException.
+ */
+ private static RuntimeException unwrapCause(Throwable exception) {
+ Throwable next = exception;
+ while (next != null && (next instanceof SparkException || next instanceof UserCodeException)) {
+ exception = next;
+ next = next.getCause();
+ }
+ return exception == next
+ ? runtimeExceptionFrom(exception)
+ : new Pipeline.PipelineExecutionException(firstNonNull(next, exception));
+ }
+
+ private State awaitTermination(Duration duration)
+ throws TimeoutException, ExecutionException, InterruptedException {
+ pipelineExecution.get(duration.getMillis(), TimeUnit.MILLISECONDS);
+ // Throws an exception if the job is not finished successfully in the given time.
+ return PipelineResult.State.DONE;
+ }
+
+ @Override
+ public PipelineResult.State getState() {
+ return state;
+ }
+
+ @Override
+ public PipelineResult.State waitUntilFinish() {
+ return waitUntilFinish(Duration.millis(Long.MAX_VALUE));
+ }
+
+ @Override
+ public State waitUntilFinish(final Duration duration) {
+ try {
+ State finishState = awaitTermination(duration);
+ offerNewState(finishState);
+ } catch (final TimeoutException e) {
+ // ignore.
+ } catch (final ExecutionException e) {
+ offerNewState(PipelineResult.State.FAILED);
+ throw unwrapCause(firstNonNull(e.getCause(), e));
+ } catch (final Exception e) {
+ offerNewState(PipelineResult.State.FAILED);
+ throw unwrapCause(e);
+ }
+
+ return state;
+ }
+
+ @Override
+ public MetricResults metrics() {
+ return asAttemptedOnlyMetricResults(metrics.value());
+ }
+
+ @Override
+ public PipelineResult.State cancel() throws IOException {
+ offerNewState(PipelineResult.State.CANCELLED);
+ return state;
+ }
Review Comment:

The `cancel()` method should attempt to cancel the underlying
`pipelineExecution` future. This is necessary to interrupt the driver thread
that is triggering Spark actions (e.g., in `EvaluationContext.evaluate`),
especially when `onTerminalState` is null (which happens when an active Spark
session is used and the runner is configured not to stop it). Without this, the
Spark jobs may continue to run even after the pipeline state is set to
`CANCELLED`.
```java
public PipelineResult.State cancel() throws IOException {
pipelineExecution.cancel(true);
offerNewState(PipelineResult.State.CANCELLED);
return state;
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]