scwhittle commented on code in PR #32775:
URL: https://github.com/apache/beam/pull/32775#discussion_r1804316412
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java:
##########
@@ -280,29 +300,30 @@ private synchronized void
consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
}
/** Close the streams that are no longer valid asynchronously. */
- @SuppressWarnings("FutureReturnValueIgnored")
- private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
+ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
StreamingEngineBackends currentBackends = backends.get();
- ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
- currentBackends.windmillStreams();
- currentWindmillStreams.entrySet().stream()
+ currentBackends.windmillStreams().entrySet().stream()
.filter(
connectionAndStream ->
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
.forEach(
- entry ->
- CompletableFuture.runAsync(
- () -> closeStreamSender(entry.getKey(), entry.getValue()),
- windmillStreamManager));
+ entry -> {
+ CompletableFuture<Void> ignored =
Review Comment:
Is this any different from just executing directly? If not, it seems simpler
to avoid the future:
windmillStreamManager.execute(
() -> closeStreamSender(entry.getKey(), entry.getValue()))
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java:
##########
@@ -445,10 +400,14 @@ private void waitForBudgetDistribution() throws
InterruptedException {
getWorkBudgetDistributorTriggered.await(5, TimeUnit.SECONDS);
Review Comment:
Seems like this should return the value (the result of the await call) so tests can assert on it if they are awaiting?
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
Review Comment:
This is a lot of variants of the create method.
Can we get rid of this one? It is easy for a test to call new ThrottleTimer() if
it doesn't care.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
Review Comment:
Would be better to verify the header as well.
Ditto for the others.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+ .isEqualTo(extension(newBudget));
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream = createGetWorkStream(testStub, initialBudget);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+ stream.setBudget(newBudget);
+ GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Header and extension.
+ assertThat(requests).hasSize(expectedRequests);
+
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+ }
+
+ @Test
+ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
Review Comment:
This may never succeed, since getRequest will return the default
object? Or perhaps it throws?
In either case you could just check the type of the one-of via
assertTrue(...hasRequest()) or something similar. But it could be better to
have a matcher that you can use here and in other tests that lets you verify the
budget within the initial request.
Ditto for the other instanceof checks.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
- private static GetWorkBudgetDistributor
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
- return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
- }
+ private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+ // Lambdas are final and cannot be spied.
+ return spy(
+ new GetWorkBudgetSpender() {
- private static GetWorkBudgetDistributor createBudgetDistributor(long
activeWorkItemsAndBytes) {
- return createBudgetDistributor(
- GetWorkBudget.builder()
- .setItems(activeWorkItemsAndBytes)
- .setBytes(activeWorkItemsAndBytes)
- .build());
+ @Override
+ public void setBudget(long items, long bytes) {}
+ });
}
@Test
public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
- createBudgetDistributor(1L)
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
ImmutableList.of(),
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
}
@Test
public void testDistributeBudget_doesNothingWithNoBudget() {
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
- createBudgetDistributor(1L)
+ GetWorkBudgetSpender getWorkBudgetSpender =
spy(createGetWorkBudgetOwner());
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(ImmutableList.of(getWorkBudgetSpender),
GetWorkBudget.noBudget());
verifyNoInteractions(getWorkBudgetSpender);
}
- @Test
- public void
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighNoActiveWork()
{
- GetWorkBudgetSpender getWorkBudgetSpender =
- spy(
- createGetWorkBudgetOwnerWithRemainingBudgetOf(
- GetWorkBudget.builder().setItems(10L).setBytes(10L).build()));
- createBudgetDistributor(0L)
- .distributeBudget(
- ImmutableList.of(getWorkBudgetSpender),
- GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
-
- verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
- }
-
- @Test
- public void
-
testDistributeBudget_doesNotAdjustStreamBudgetWhenRemainingBudgetHighWithActiveWork()
{
- GetWorkBudgetSpender getWorkBudgetSpender =
- spy(
- createGetWorkBudgetOwnerWithRemainingBudgetOf(
- GetWorkBudget.builder().setItems(5L).setBytes(5L).build()));
- createBudgetDistributor(10L)
- .distributeBudget(
- ImmutableList.of(getWorkBudgetSpender),
- GetWorkBudget.builder().setItems(20L).setBytes(20L).build());
-
- verify(getWorkBudgetSpender, never()).adjustBudget(anyLong(), anyLong());
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithNoActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(0L)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingItemBudgetTooLowWithActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(1L).setBytes(10L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- long activeWorkItemsAndBytes = 2L;
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(activeWorkItemsAndBytes)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(
- totalGetWorkBudget.items()
- - streamRemainingBudget.items()
- - activeWorkItemsAndBytes),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowNoActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(0L)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(totalGetWorkBudget.bytes() - streamRemainingBudget.bytes()));
- }
-
- @Test
- public void
-
testDistributeBudget_adjustsStreamBudgetWhenRemainingByteBudgetTooLowWithActiveWork()
{
- GetWorkBudget streamRemainingBudget =
- GetWorkBudget.builder().setItems(10L).setBytes(1L).build();
- GetWorkBudget totalGetWorkBudget =
GetWorkBudget.builder().setItems(10L).setBytes(10L).build();
- long activeWorkItemsAndBytes = 2L;
-
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(streamRemainingBudget));
- createBudgetDistributor(activeWorkItemsAndBytes)
- .distributeBudget(ImmutableList.of(getWorkBudgetSpender),
totalGetWorkBudget);
-
- verify(getWorkBudgetSpender, times(1))
- .adjustBudget(
- eq(totalGetWorkBudget.items() - streamRemainingBudget.items()),
- eq(
- totalGetWorkBudget.bytes()
- - streamRemainingBudget.bytes()
- - activeWorkItemsAndBytes));
- }
-
@Test
public void testDistributeBudget_distributesBudgetEvenlyIfPossible() {
long totalItemsAndBytes = 10L;
Review Comment:
would be better to have different items and bytes values to confirm
distributor doesn't mix them up internally
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributor.java:
##########
@@ -61,38 +43,15 @@ public <T extends GetWorkBudgetSpender> void
distributeBudget(
return;
}
- Map<T, GetWorkBudget> desiredBudgets = computeDesiredBudgets(budgetOwners,
getWorkBudget);
-
- for (Entry<T, GetWorkBudget> streamAndDesiredBudget :
desiredBudgets.entrySet()) {
- GetWorkBudgetSpender getWorkBudgetSpender =
streamAndDesiredBudget.getKey();
- GetWorkBudget desired = streamAndDesiredBudget.getValue();
- GetWorkBudget remaining = getWorkBudgetSpender.remainingBudget();
- if (isBelowFiftyPercentOfTarget(remaining, desired)) {
- GetWorkBudget adjustment = desired.subtract(remaining);
- getWorkBudgetSpender.adjustBudget(adjustment);
- }
- }
+ GetWorkBudget budgetPerStream = computeDesiredBudgets(budgetSpenders,
getWorkBudget);
+ budgetSpenders.forEach(getWorkBudgetSpender ->
getWorkBudgetSpender.setBudget(budgetPerStream));
}
- private <T extends GetWorkBudgetSpender> ImmutableMap<T, GetWorkBudget>
computeDesiredBudgets(
+ private <T extends GetWorkBudgetSpender> GetWorkBudget computeDesiredBudgets(
Review Comment:
nit: maybe name it computeDesiredPerStreamBudget? Or just inline it?
"Budgets" makes it sound like it is computing multiple.
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java:
##########
@@ -267,25 +277,100 @@ protected void startThrottleTimer() {
}
@Override
- public void adjustBudget(long itemsDelta, long bytesDelta) {
- GetWorkBudget adjustment =
- nextBudgetAdjustment
- // Get the current value, and reset the nextBudgetAdjustment. This
will be set again
- // when adjustBudget is called.
- .getAndUpdate(unused -> GetWorkBudget.noBudget())
- .apply(itemsDelta, bytesDelta);
- sendRequestExtension(adjustment);
+ public void setBudget(long newItems, long newBytes) {
+ GetWorkBudget currentMaxGetWorkBudget =
+ maxGetWorkBudget.updateAndGet(
+ ignored ->
GetWorkBudget.builder().setItems(newItems).setBytes(newBytes).build());
+ GetWorkBudget extension =
budgetTracker.computeBudgetExtension(currentMaxGetWorkBudget);
+ maybeSendRequestExtension(extension);
}
- @Override
- public GetWorkBudget remainingBudget() {
- // Snapshot the current budgets.
- GetWorkBudget currentPendingResponseBudget = pendingResponseBudget.get();
- GetWorkBudget currentNextBudgetAdjustment = nextBudgetAdjustment.get();
- GetWorkBudget currentInflightBudget = inFlightBudget.get();
-
- return currentPendingResponseBudget
- .apply(currentNextBudgetAdjustment)
- .apply(currentInflightBudget);
+ /**
+ * Tracks sent and received GetWorkBudget and uses this information to
generate request
+ * extensions.
+ */
+ @AutoValue
+ abstract static class GetWorkBudgetTracker {
+
+ private static GetWorkBudgetTracker create() {
+ return new AutoValue_GrpcDirectGetWorkStream_GetWorkBudgetTracker(
+ new AtomicLong(), new AtomicLong(), new AtomicLong(), new
AtomicLong());
+ }
+
+ abstract AtomicLong itemsRequested();
Review Comment:
can the members be changed to just raw longs/objects? The accessors just
need to be synchronized as well.
Seems like this could be easier without autovalue since we don't need the
accessors either.
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java:
##########
@@ -231,7 +236,7 @@ public void appendSpecificHtml(PrintWriter writer) {
+ "total budget received: %s,"
+ "last sent request: %s. ",
workItemAssemblers.size(),
- maxGetWorkBudget.get(),
+ budgetTracker.maxGetWorkBudget().get(),
Review Comment:
could move html generation into budgettracker and not need all the
accessors. If we change how the tracker works in the future we might want to
show more too.
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java:
##########
@@ -41,6 +41,14 @@
public abstract class WindmillEndpoints {
private static final Logger LOG =
LoggerFactory.getLogger(WindmillEndpoints.class);
+ public static WindmillEndpoints none() {
+ return WindmillEndpoints.builder()
+ .setVersion(Long.MAX_VALUE)
Review Comment:
min seems safer. Otherwise, if none() were somehow observed, the logic that
ensures the version is increasing means we'd never process another endpoint set.
##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java:
##########
@@ -280,29 +300,30 @@ private synchronized void
consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
}
/** Close the streams that are no longer valid asynchronously. */
- @SuppressWarnings("FutureReturnValueIgnored")
- private void closeStaleStreams(WindmillEndpoints newWindmillEndpoints) {
+ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) {
StreamingEngineBackends currentBackends = backends.get();
- ImmutableMap<Endpoint, WindmillStreamSender> currentWindmillStreams =
- currentBackends.windmillStreams();
- currentWindmillStreams.entrySet().stream()
+ currentBackends.windmillStreams().entrySet().stream()
.filter(
connectionAndStream ->
!newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey()))
.forEach(
- entry ->
- CompletableFuture.runAsync(
- () -> closeStreamSender(entry.getKey(), entry.getValue()),
- windmillStreamManager));
+ entry -> {
+ CompletableFuture<Void> ignored =
+ CompletableFuture.runAsync(
+ () -> closeStreamSender(entry.getKey(),
entry.getValue()),
+ windmillStreamManager);
+ });
Set<Endpoint> newGlobalDataEndpoints =
new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values());
currentBackends.globalDataStreams().values().stream()
.filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint()))
.forEach(
- sender ->
- CompletableFuture.runAsync(
- () -> closeStreamSender(sender.endpoint(), sender),
windmillStreamManager));
+ sender -> {
+ CompletableFuture<Void> ignored =
Review Comment:
ditto
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -192,17 +82,17 @@ public void
testDistributeBudget_distributesBudgetEvenlyIfPossible() {
streams.forEach(
Review Comment:
just skip the math in the test and inline the right values?
The math is just copying what we have in the impl, if there is some bug in
the impl hard coding the values at least is a sanity check.
ditto below.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
Review Comment:
Use assertTrue on awaits that should succeed. If an await times out, the
later assertions will fail too, and the resulting failure may be confusing.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/budget/EvenGetWorkBudgetDistributorTest.java:
##########
@@ -40,147 +38,39 @@ public class EvenGetWorkBudgetDistributorTest {
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
@Rule public transient Timeout globalTimeout = Timeout.seconds(600);
- private static GetWorkBudgetDistributor
createBudgetDistributor(GetWorkBudget activeWorkBudget) {
- return GetWorkBudgetDistributors.distributeEvenly(() -> activeWorkBudget);
- }
+ private static GetWorkBudgetSpender createGetWorkBudgetOwner() {
+ // Lambdas are final and cannot be spied.
+ return spy(
+ new GetWorkBudgetSpender() {
- private static GetWorkBudgetDistributor createBudgetDistributor(long
activeWorkItemsAndBytes) {
- return createBudgetDistributor(
- GetWorkBudget.builder()
- .setItems(activeWorkItemsAndBytes)
- .setBytes(activeWorkItemsAndBytes)
- .build());
+ @Override
+ public void setBudget(long items, long bytes) {}
+ });
}
@Test
public void testDistributeBudget_doesNothingWhenPassedInStreamsEmpty() {
- createBudgetDistributor(1L)
+ GetWorkBudgetDistributors.distributeEvenly()
.distributeBudget(
ImmutableList.of(),
GetWorkBudget.builder().setItems(10L).setBytes(10L).build());
}
@Test
public void testDistributeBudget_doesNothingWithNoBudget() {
- GetWorkBudgetSpender getWorkBudgetSpender =
-
spy(createGetWorkBudgetOwnerWithRemainingBudgetOf(GetWorkBudget.noBudget()));
- createBudgetDistributor(1L)
+ GetWorkBudgetSpender getWorkBudgetSpender =
spy(createGetWorkBudgetOwner());
Review Comment:
remove spy here? already done in the helper method
ditto for below
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
Review Comment:
ditto, tests can just use full 4 params and make a throttle timer
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+ .isEqualTo(extension(newBudget));
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream = createGetWorkStream(testStub, initialBudget);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+ stream.setBudget(newBudget);
+ GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Header and extension.
+ assertThat(requests).hasSize(expectedRequests);
+
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+ }
+
+ @Test
+ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testSetBudget_doesNothingIfStreamShutdown() throws
InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ stream.shutdown();
+ stream.setBudget(
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(1);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws
InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(1).setBytes(100).build();
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+ initialBudget,
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize();
+
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requests).getRequestExtension())
+ .isEqualTo(
+ extension(
+ GetWorkBudget.builder()
+ .setItems(1)
+ .setBytes(initialBudget.bytes() - inFlightBytes)
+ .build()));
+ }
+
+ @Test
+ public void
testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testOnResponse_stopsThrottling() {
+ ThrottleTimer throttleTimer = new ThrottleTimer();
+ TestGetWorkRequestObserver requestObserver =
+ new TestGetWorkRequestObserver(new CountDownLatch(1));
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(),
throttleTimer);
+ stream.startThrottleTimer();
Review Comment:
assertTrue(throttleTimer.throttled())
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+ .isEqualTo(extension(newBudget));
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream = createGetWorkStream(testStub, initialBudget);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+ stream.setBudget(newBudget);
+ GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Header and extension.
+ assertThat(requests).hasSize(expectedRequests);
+
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+ }
+
+ @Test
+ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testSetBudget_doesNothingIfStreamShutdown() throws
InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ stream.shutdown();
+ stream.setBudget(
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(1);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws
InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(1).setBytes(100).build();
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+ initialBudget,
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize();
+
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requests).getRequestExtension())
+ .isEqualTo(
+ extension(
+ GetWorkBudget.builder()
+ .setItems(1)
+ .setBytes(initialBudget.bytes() - inFlightBytes)
+ .build()));
+ }
+
+ @Test
+ public void
testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testOnResponse_stopsThrottling() {
+ ThrottleTimer throttleTimer = new ThrottleTimer();
+ TestGetWorkRequestObserver requestObserver =
+ new TestGetWorkRequestObserver(new CountDownLatch(1));
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(),
throttleTimer);
+ stream.startThrottleTimer();
+
testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());
Review Comment:
does this run inline? otherwise it seems like the check below that
unthrottling happens could be racy.
##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java:
##########
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.client.grpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.Mockito.mock;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
+import
org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
+import
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer;
+import org.apache.beam.runners.dataflow.worker.windmill.work.WorkItemScheduler;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
+import
org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder;
+import
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.Timeout;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GrpcDirectGetWorkStreamTest {
+ private static final Windmill.JobHeader TEST_JOB_HEADER =
+ Windmill.JobHeader.newBuilder()
+ .setJobId("test_job")
+ .setWorkerId("test_worker")
+ .setProjectId("test_project")
+ .build();
+ private static final String FAKE_SERVER_NAME = "Fake server for
GrpcDirectGetWorkStreamTest";
+ @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private final MutableHandlerRegistry serviceRegistry = new
MutableHandlerRegistry();
+ @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
+ private ManagedChannel inProcessChannel;
+ private GrpcDirectGetWorkStream stream;
+
+ private static Windmill.StreamingGetWorkRequestExtension
extension(GetWorkBudget budget) {
+ return Windmill.StreamingGetWorkRequestExtension.newBuilder()
+ .setMaxItems(budget.items())
+ .setMaxBytes(budget.bytes())
+ .build();
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ Server server =
+ InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+ .fallbackHandlerRegistry(serviceRegistry)
+ .directExecutor()
+ .build()
+ .start();
+
+ inProcessChannel =
+ grpcCleanup.register(
+
InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+ grpcCleanup.register(server);
+ grpcCleanup.register(inProcessChannel);
+ }
+
+ @After
+ public void cleanUp() {
+ inProcessChannel.shutdownNow();
+ checkNotNull(stream).shutdown();
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub, GetWorkBudget initialGetWorkBudget) {
+ return createGetWorkStream(testStub, initialGetWorkBudget, new
ThrottleTimer());
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ WorkItemScheduler workItemScheduler) {
+ return createGetWorkStream(
+ testStub, initialGetWorkBudget, new ThrottleTimer(),
workItemScheduler);
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer) {
+ return createGetWorkStream(
+ testStub,
+ initialGetWorkBudget,
+ throttleTimer,
+ (workItem, watermarks, processingContext, getWorkStreamLatencies) ->
{});
+ }
+
+ private GrpcDirectGetWorkStream createGetWorkStream(
+ GetWorkStreamTestStub testStub,
+ GetWorkBudget initialGetWorkBudget,
+ ThrottleTimer throttleTimer,
+ WorkItemScheduler workItemScheduler) {
+ serviceRegistry.addService(testStub);
+ return (GrpcDirectGetWorkStream)
+ GrpcWindmillStreamFactory.of(TEST_JOB_HEADER)
+ .build()
+ .createDirectGetWorkStream(
+ WindmillConnection.builder()
+
.setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel))
+ .build(),
+ Windmill.GetWorkRequest.newBuilder()
+ .setClientId(TEST_JOB_HEADER.getClientId())
+ .setJobId(TEST_JOB_HEADER.getJobId())
+ .setProjectId(TEST_JOB_HEADER.getProjectId())
+ .setWorkerId(TEST_JOB_HEADER.getWorkerId())
+ .setMaxItems(initialGetWorkBudget.items())
+ .setMaxBytes(initialGetWorkBudget.bytes())
+ .build(),
+ throttleTimer,
+ mock(HeartbeatSender.class),
+ mock(GetDataClient.class),
+ mock(WorkCommitter.class),
+ workItemScheduler);
+ }
+
+ private Windmill.StreamingGetWorkResponseChunk
createResponse(Windmill.WorkItem workItem) {
+ return Windmill.StreamingGetWorkResponseChunk.newBuilder()
+ .setStreamId(1L)
+ .setComputationMetadata(
+ Windmill.ComputationWorkItemMetadata.newBuilder()
+ .setComputationId("compId")
+ .setInputDataWatermark(1L)
+ .setDependentRealtimeInputWatermark(1L)
+ .build())
+ .setSerializedWorkItem(workItem.toByteString())
+ .setRemainingBytesForWorkItem(0)
+ .build();
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_noExistingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream.setBudget(newBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ // Header and extension.
+ assertThat(requestObserver.sent()).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requestObserver.sent()).getRequestExtension())
+ .isEqualTo(extension(newBudget));
+ }
+
+ @Test
+ public void testSetBudget_computesAndSendsCorrectExtension_existingBudget()
+ throws InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(10).setBytes(10).build();
+ stream = createGetWorkStream(testStub, initialBudget);
+ GetWorkBudget newBudget =
GetWorkBudget.builder().setItems(100).setBytes(100).build();
+ stream.setBudget(newBudget);
+ GetWorkBudget diff = newBudget.subtract(initialBudget);
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Header and extension.
+ assertThat(requests).hasSize(expectedRequests);
+
assertThat(Iterables.getLast(requests).getRequestExtension()).isEqualTo(extension(diff));
+ }
+
+ @Test
+ public void testSetBudget_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
stream.setBudget(GetWorkBudget.builder().setItems(10).setBytes(10).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testSetBudget_doesNothingIfStreamShutdown() throws
InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget());
+ stream.shutdown();
+ stream.setBudget(
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build());
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(1);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testConsumedWorkItem_computesAndSendsCorrectExtension() throws
InterruptedException {
+ int expectedRequests = 2;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ GetWorkBudget initialBudget =
GetWorkBudget.builder().setItems(1).setBytes(100).build();
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+ initialBudget,
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+ long inFlightBytes = initialBudget.bytes() - workItem.getSerializedSize();
+
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getLast(requests).getRequestExtension())
+ .isEqualTo(
+ extension(
+ GetWorkBudget.builder()
+ .setItems(1)
+ .setBytes(initialBudget.bytes() - inFlightBytes)
+ .build()));
+ }
+
+ @Test
+ public void
testConsumedWorkItem_doesNotSendExtensionIfOutstandingBudgetHigh()
+ throws InterruptedException {
+ int expectedRequests = 1;
+ CountDownLatch waitForRequests = new CountDownLatch(expectedRequests);
+ TestGetWorkRequestObserver requestObserver = new
TestGetWorkRequestObserver(waitForRequests);
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ Set<Windmill.WorkItem> scheduledWorkItems = new HashSet<>();
+ stream =
+ createGetWorkStream(
+ testStub,
+
GetWorkBudget.builder().setItems(Long.MAX_VALUE).setBytes(Long.MAX_VALUE).build(),
+ (work, watermarks, processingContext, getWorkStreamLatencies) -> {
+ scheduledWorkItems.add(work);
+ });
+ Windmill.WorkItem workItem =
+ Windmill.WorkItem.newBuilder()
+ .setKey(ByteString.copyFromUtf8("somewhat_long_key"))
+ .setWorkToken(1L)
+ .setShardingKey(1L)
+ .setCacheToken(1L)
+ .build();
+
+ testStub.injectResponse(createResponse(workItem));
+
+ waitForRequests.await(5, TimeUnit.SECONDS);
+
+ assertThat(scheduledWorkItems).containsExactly(workItem);
+ List<Windmill.StreamingGetWorkRequest> requests = requestObserver.sent();
+
+ // Assert that the extension was never sent, only the header.
+ assertThat(requests).hasSize(expectedRequests);
+ assertThat(Iterables.getOnlyElement(requests).getRequest())
+ .isInstanceOf(Windmill.GetWorkRequest.class);
+ }
+
+ @Test
+ public void testOnResponse_stopsThrottling() {
+ ThrottleTimer throttleTimer = new ThrottleTimer();
+ TestGetWorkRequestObserver requestObserver =
+ new TestGetWorkRequestObserver(new CountDownLatch(1));
+ GetWorkStreamTestStub testStub = new
GetWorkStreamTestStub(requestObserver);
+ stream = createGetWorkStream(testStub, GetWorkBudget.noBudget(),
throttleTimer);
+ stream.startThrottleTimer();
+
testStub.injectResponse(Windmill.StreamingGetWorkResponseChunk.getDefaultInstance());
+ assertFalse(throttleTimer.throttled());
+ }
+
+ private static class GetWorkStreamTestStub
+ extends
CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+ private final TestGetWorkRequestObserver requestObserver;
Review Comment:
might as well mark volatile to prevent races
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]