scwhittle commented on code in PR #27767:
URL: https://github.com/apache/beam/pull/27767#discussion_r1295579721


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/ForwardingClientResponseObserver.java:
##########
@@ -27,23 +27,23 @@
  * <p>Used to wrap existing {@link StreamObserver}s to be able to install an {@link
  * ClientCallStreamObserver#setOnReadyHandler(Runnable) onReadyHandler}.
  *
- * <p>This is as thread-safe as the undering stream observer that is being wrapped.
+ * <p>This is as thread-safe as the underlying stream observer that is being wrapped.
  */
-final class ForwardingClientResponseObserver<ReqT, RespT>
-    implements ClientResponseObserver<RespT, ReqT> {
+final class ForwardingClientResponseObserver<ResponseT, RequestT>

Review Comment:
   Can you rebase? I think much of this was already merged with your previous PR.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java:
##########
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill;
+
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
+
+import com.google.auto.value.AutoValue;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+/** Value class for holding endpoints used for communicating with Windmill service. */
+@AutoValue
+public abstract class WindmillEndpoints {
+  public static WindmillEndpoints from(
+      Windmill.WorkerMetadataResponse workerMetadataResponseProto) {
+    ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers =
+        workerMetadataResponseProto.getGlobalDataEndpointsMap().entrySet().stream()
+            .collect(
+                toImmutableMap(
+                    Map.Entry::getKey, // global data key
+                    endpoint -> WindmillEndpoints.Endpoint.from(endpoint.getValue())));
+
+    ImmutableList<WindmillEndpoints.Endpoint> windmillServers =
+        workerMetadataResponseProto.getDirectPathEndpointsList().stream()
+            .map(WindmillEndpoints.Endpoint::from)
+            .collect(toImmutableList());
+
+    return WindmillEndpoints.builder()
+        .setGlobalDataServers(globalDataServers)
+        .setWindmillServers(windmillServers)
+        .build();
+  }
+
+  public static WindmillEndpoints.Builder builder() {
+    return new AutoValue_WindmillEndpoints.Builder();
+  }
+
+  /**
+   * Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key
+   * is a global data tag and the value is the endpoint where the data associated with the global
+   * data tag resides.
+   *
+   * @see <a href="https://beam.apache.org/documentation/programming-guide/#side-inputs">Beam Side
+   *     Inputs</a>
+   */
+  public abstract ImmutableMap<String, Endpoint> globalDataServers();
+
+  /**
+   * Used by GetWork/GetData/CommitWork calls to send, receive, and commit work directly to/from
+   * Windmill servers. Returns a list of endpoints used to communicate with the corresponding
+   * Windmill servers.
+   */
+  public abstract ImmutableList<Endpoint> windmillServers();
+
+  @AutoValue

Review Comment:
   Just checking: AutoValue automatically generates the equals/hashCode methods, right?
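
   For reference, AutoValue does generate value-based equals, hashCode, and toString. A minimal hand-written sketch of what that generated behavior looks like follows. The class name `EndpointValueSketch` and its single field are hypothetical stand-ins, since the fields of the real `WindmillEndpoints.Endpoint` are not shown in this hunk.

```java
import java.util.Objects;
import java.util.Optional;

// Hypothetical stand-in for an @AutoValue class. The real
// WindmillEndpoints.Endpoint fields are not visible in this diff.
final class EndpointValueSketch {
  private final Optional<String> directEndpoint;

  private EndpointValueSketch(Optional<String> directEndpoint) {
    this.directEndpoint = directEndpoint;
  }

  static EndpointValueSketch of(String directEndpoint) {
    return new EndpointValueSketch(Optional.ofNullable(directEndpoint));
  }

  // Value-based equality: two instances are equal when their fields are equal.
  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (!(other instanceof EndpointValueSketch)) {
      return false;
    }
    return directEndpoint.equals(((EndpointValueSketch) other).directEndpoint);
  }

  // Hash code derived from the same fields used by equals.
  @Override
  public int hashCode() {
    return Objects.hash(directEndpoint);
  }

  @Override
  public String toString() {
    return "EndpointValueSketch{directEndpoint=" + directEndpoint + "}";
  }

  public static void main(String[] args) {
    EndpointValueSketch a = EndpointValueSketch.of("direct_path");
    EndpointValueSketch b = EndpointValueSketch.of("direct_path");
    System.out.println(a.equals(b)); // true
    System.out.println(a.hashCode() == b.hashCode()); // true
  }
}
```

   This matters for the tests in this PR, which compare endpoint objects with containsAll and Mockito's verify; both rely on equals.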



##########
runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto:
##########
@@ -758,7 +758,8 @@ message WorkerMetadataResponse {
   // CommitWorkStream. Each response on this stream replaces the previous, and
   // connections to endpoints that are no longer present should be closed.
   message Endpoint {
-    optional string endpoint = 1;
+    optional string direct_endpoint = 1;
+    optional string worker_token = 2;

Review Comment:
   I think we want to remove direct_path_endpoints and just use work_endpoints.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java:
##########
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.grpcclient;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.util.MutableHandlerRegistry;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class GrpcGetWorkerMetadataStreamTest {
+  private static final List<WorkerMetadataResponse.Endpoint> DIRECT_PATH_ENDPOINTS =
+      Lists.newArrayList(
+          WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("direct_path").build());
+  private static final Map<String, WorkerMetadataResponse.Endpoint> GLOBAL_DATA_ENDPOINTS =
+      Maps.newHashMap();
+  private static final JobHeader TEST_JOB_HEADER =
+      JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new HashSet<>();
+  private ManagedChannel inProcessChannel;
+  private GrpcGetWorkerMetadataStream stream;
+
+  private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream(
+      GetWorkerMetadataTestStub getWorkerMetadataTestStub,
+      AtomicLong metadataVersion,
+      Consumer<WindmillEndpoints> endpointsConsumer) {
+    serviceRegistry.addService(getWorkerMetadataTestStub);
+    return GrpcGetWorkerMetadataStream.forTesting(
+        responseObserver ->
+            CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
+                .getWorkerMetadataStream(responseObserver),
+        metadataVersion,
+        TEST_JOB_HEADER,
+        streamRegistry,
+        endpointsConsumer);
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+    GLOBAL_DATA_ENDPOINTS.put(
+        "global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("global_data").build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  @Test
+  public void testGetWorkerMetadata() {
+    WorkerMetadataResponse mockResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        new TestWindmillEndpointsConsumer();
+    TestGetWorkMetadataRequestObserver requestObserver =
+        new TestGetWorkMetadataRequestObserver(mockResponse);
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    AtomicLong metadataVersion = new AtomicLong(0);
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+    assertEquals(
+        GLOBAL_DATA_ENDPOINTS.size(), testWindmillEndpointsConsumer.globalDataEndpoints.size());
+    testWindmillEndpointsConsumer.globalDataEndpoints.forEach(
+        (key, value) -> assertTrue(GLOBAL_DATA_ENDPOINTS.containsKey(key)));
+    assertEquals(
+        DIRECT_PATH_ENDPOINTS.size(), testWindmillEndpointsConsumer.windmillEndpoints.size());
+    assertTrue(
+        testWindmillEndpointsConsumer.windmillEndpoints.containsAll(
+            DIRECT_PATH_ENDPOINTS.stream()
+                .map(WindmillEndpoints.Endpoint::from)
+                .collect(Collectors.toList())));
+    assertEquals(mockResponse.getMetadataVersion(), metadataVersion.get());
+  }
+
+  @Test
+  public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() {
+    WorkerMetadataResponse initialResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        Mockito.spy(new TestWindmillEndpointsConsumer());
+    TestGetWorkMetadataRequestObserver requestObserver =
+        new TestGetWorkMetadataRequestObserver(initialResponse);
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    AtomicLong metadataVersion = new AtomicLong(0);
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+
+    List<WorkerMetadataResponse.Endpoint> newDirectPathEndpoints =
+        Lists.newArrayList(
+            WorkerMetadataResponse.Endpoint.newBuilder()
+                .setDirectEndpoint("newWindmillEndpoint")
+                .build());
+    Map<String, WorkerMetadataResponse.Endpoint> newGlobalDataEndpoints = new HashMap<>();
+    newGlobalDataEndpoints.put(
+        "new_global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("newGlobalData").build());
+
+    WorkerMetadataResponse newWorkMetadataResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(initialResponse.getMetadataVersion() + 1)
+            .addAllDirectPathEndpoints(newDirectPathEndpoints)
+            .putAllGlobalDataEndpoints(newGlobalDataEndpoints)
+            .build();
+
+    testStub.injectWorkerMetadata(newWorkMetadataResponse);
+
+    assertEquals(
+        newGlobalDataEndpoints.size(), testWindmillEndpointsConsumer.globalDataEndpoints.size());
+    testWindmillEndpointsConsumer.globalDataEndpoints.forEach(
+        (key, value) -> assertTrue(newGlobalDataEndpoints.containsKey(key)));
+    assertEquals(
+        newDirectPathEndpoints.size(), testWindmillEndpointsConsumer.windmillEndpoints.size());
+    assertTrue(
+        testWindmillEndpointsConsumer.windmillEndpoints.containsAll(
+            newDirectPathEndpoints.stream()
+                .map(WindmillEndpoints.Endpoint::from)
+                .collect(Collectors.toList())));
+    assertEquals(newWorkMetadataResponse.getMetadataVersion(), metadataVersion.get());
+  }
+
+  @Test
+  public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
+    WorkerMetadataResponse freshEndpoints =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(2)
+            .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        Mockito.spy(new TestWindmillEndpointsConsumer());
+    TestGetWorkMetadataRequestObserver requestObserver =
+        new TestGetWorkMetadataRequestObserver(freshEndpoints);
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    AtomicLong metadataVersion = new AtomicLong(0);
+    stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer);
+    List<WorkerMetadataResponse.Endpoint> staleDirectPathEndpoints =
+        Lists.newArrayList(
+            WorkerMetadataResponse.Endpoint.newBuilder()
+                .setDirectEndpoint("staleWindmillEndpoint")
+                .build());
+    Map<String, WorkerMetadataResponse.Endpoint> staleGlobalDataEndpoints = new HashMap<>();
+    staleGlobalDataEndpoints.put(
+        "stale_global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("staleGlobalData").build());
+
+    testStub.injectWorkerMetadata(
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllDirectPathEndpoints(staleDirectPathEndpoints)
+            .putAllGlobalDataEndpoints(staleGlobalDataEndpoints)
+            .build());
+
+    // Should have ignored the stale update and only used initial.
+    verify(testWindmillEndpointsConsumer).accept(WindmillEndpoints.from(freshEndpoints));
+    verifyNoMoreInteractions(testWindmillEndpointsConsumer);
+    assertEquals(freshEndpoints.getMetadataVersion(), metadataVersion.get());
+  }
+
+  @Test
+  public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() {
+    TestGetWorkMetadataRequestObserver requestObserver =
+        new TestGetWorkMetadataRequestObserver(
+            WorkerMetadataResponse.newBuilder()
+                .setMetadataVersion(1)
+                .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+                .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+                .build());
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    stream =
+        getWorkerMetadataTestStream(
+            testStub, new AtomicLong(), new TestWindmillEndpointsConsumer());
+
+    assertTrue(streamRegistry.contains(stream));
+    stream.close();
+    assertFalse(streamRegistry.contains(stream));
+  }
+
+  @Test
+  public void testSendHealthCheck() {
+    TestGetWorkMetadataRequestObserver requestObserver =
+        Mockito.spy(
+            new TestGetWorkMetadataRequestObserver(
+                WorkerMetadataResponse.newBuilder()
+                    .setMetadataVersion(1)
+                    .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+                    .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+                    .build()));
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    stream =
+        getWorkerMetadataTestStream(
+            testStub, new AtomicLong(), new TestWindmillEndpointsConsumer());
+    stream.sendHealthCheck();
+
+    verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance());
+  }
+
+  private static class GetWorkerMetadataTestStub
+      extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
+    private final TestGetWorkMetadataRequestObserver requestObserver;
+    private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
+
+    private GetWorkerMetadataTestStub(TestGetWorkMetadataRequestObserver requestObserver) {
+      this.requestObserver = requestObserver;
+    }
+
+    @Override
+    public StreamObserver<WorkerMetadataRequest> getWorkerMetadataStream(
+        StreamObserver<WorkerMetadataResponse> responseObserver) {
+      if (this.responseObserver == null) {
+        this.responseObserver = responseObserver;
+        requestObserver.responseObserver = this.responseObserver;
+      }
+
+      return requestObserver;
+    }
+
+    private void injectWorkerMetadata(WorkerMetadataResponse response) {
+      if (responseObserver != null) {
+        responseObserver.onNext(response);
+      }
+    }
+  }
+
+  private static class TestGetWorkMetadataRequestObserver
+      implements StreamObserver<WorkerMetadataRequest> {
+    private final WorkerMetadataResponse response;
+    private @Nullable StreamObserver<WorkerMetadataResponse> responseObserver;
+
+    private TestGetWorkMetadataRequestObserver(WorkerMetadataResponse response) {
+      this.response = response;
+    }
+
+    @Override
+    public void onNext(WorkerMetadataRequest workerMetadataRequest) {
+      responseObserver.onNext(response);

Review Comment:
   What about removing this way of sending responses and having tests call
injectWorkerMetadata explicitly?
   I think that would make it easier to see what each test is doing.
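
   A rough sketch of that shape, with no gRPC involved: the fake stub never responds on its own, and the test decides exactly when each response is delivered. `ExplicitInjectionSketch`, `FakeMetadataStub`, `openStream`, and `inject` are hypothetical names, and a plain `Long` version stands in for `WorkerMetadataResponse`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Simplified stand-ins (assumptions): a Consumer<Long> plays the role of the
// response StreamObserver, and a Long plays the role of WorkerMetadataResponse.
final class ExplicitInjectionSketch {

  /** Fake server side: it never responds on its own; tests must call inject(). */
  static final class FakeMetadataStub {
    private Consumer<Long> responseObserver; // set when the client opens the stream

    void openStream(Consumer<Long> responseObserver) {
      this.responseObserver = responseObserver;
    }

    void inject(long metadataVersion) {
      if (responseObserver == null) {
        throw new IllegalStateException("stream not opened yet");
      }
      responseObserver.accept(metadataVersion);
    }
  }

  public static void main(String[] args) {
    FakeMetadataStub stub = new FakeMetadataStub();
    List<Long> consumed = new ArrayList<>();
    stub.openStream(consumed::add);

    // Nothing is delivered until the test says so.
    System.out.println(consumed.isEmpty()); // true

    // Each response the test depends on is visible in the test body.
    stub.inject(1L);
    stub.inject(2L);
    System.out.println(consumed); // [1, 2]
  }
}
```

   With this layout every response a test depends on appears in the test body, instead of being sent implicitly from onNext.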



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStreamTest.java:
##########
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.grpcclient;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannel;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.Server;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessChannelBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.inprocess.InProcessServerBuilder;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.testing.GrpcCleanupRule;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.util.MutableHandlerRegistry;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+@RunWith(JUnit4.class)
+public class GrpcGetWorkerMetadataStreamTest {
+  private static final List<WorkerMetadataResponse.Endpoint> DIRECT_PATH_ENDPOINTS =
+      Lists.newArrayList(
+          WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("direct_path").build());
+  private static final Map<String, WorkerMetadataResponse.Endpoint> GLOBAL_DATA_ENDPOINTS =
+      Maps.newHashMap();
+  private static final JobHeader TEST_JOB_HEADER =
+      JobHeader.newBuilder()
+          .setJobId("test_job")
+          .setWorkerId("test_worker")
+          .setProjectId("test_project")
+          .build();
+  private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest";
+  @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
+  private final Set<AbstractWindmillStream<?, ?>> streamRegistry = new HashSet<>();
+  private ManagedChannel inProcessChannel;
+  private GrpcGetWorkerMetadataStream stream;
+
+  private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream(
+      GetWorkerMetadataTestStub getWorkerMetadataTestStub,
+      AtomicLong metadataVersion,
+      Consumer<WindmillEndpoints> endpointsConsumer) {
+    serviceRegistry.addService(getWorkerMetadataTestStub);
+    return GrpcGetWorkerMetadataStream.forTesting(
+        responseObserver ->
+            CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)
+                .getWorkerMetadataStream(responseObserver),
+        metadataVersion,
+        TEST_JOB_HEADER,
+        streamRegistry,
+        endpointsConsumer);
+  }
+
+  @Before
+  public void setUp() throws IOException {
+    Server server =
+        InProcessServerBuilder.forName(FAKE_SERVER_NAME)
+            .fallbackHandlerRegistry(serviceRegistry)
+            .directExecutor()
+            .build()
+            .start();
+    inProcessChannel =
+        grpcCleanup.register(
+            InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build());
+    grpcCleanup.register(server);
+    grpcCleanup.register(inProcessChannel);
+    GLOBAL_DATA_ENDPOINTS.put(
+        "global_data",
+        WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("global_data").build());
+  }
+
+  @After
+  public void cleanUp() {
+    inProcessChannel.shutdownNow();
+  }
+
+  @Test
+  public void testGetWorkerMetadata() {
+    WorkerMetadataResponse mockResponse =
+        WorkerMetadataResponse.newBuilder()
+            .setMetadataVersion(1)
+            .addAllDirectPathEndpoints(DIRECT_PATH_ENDPOINTS)
+            .putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
+            .build();
+    TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
+        new TestWindmillEndpointsConsumer();
+    TestGetWorkMetadataRequestObserver requestObserver =
+        new TestGetWorkMetadataRequestObserver(mockResponse);
+    GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver);
+    AtomicLong metadataVersion = new AtomicLong(0);

Review Comment:
   nit: use a non-zero version. 0 could look correct even if the value was just
default-initialized by accident rather than propagated.
   
   (Sorry if some of these are duplicated; I think I was putting comments on another
snapshot and am not sure where they'll show up.)
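
   To illustrate the nit about the version value, here is a small sketch of how 0 can hide a propagation bug. `BuggyVersionHolder` is a hypothetical class written only for this example; it is not code from the PR.

```java
import java.util.concurrent.atomic.AtomicLong;

final class NonZeroInitialVersionSketch {

  /** Hypothetical holder with a deliberate bug: it ignores the initial version. */
  static final class BuggyVersionHolder {
    private final AtomicLong version = new AtomicLong(); // defaults to 0

    BuggyVersionHolder(long initialVersion) {
      // Bug: initialVersion is never propagated to the field.
    }

    long version() {
      return version.get();
    }
  }

  public static void main(String[] args) {
    // With 0, the bug hides behind the default value and the check passes.
    System.out.println(new BuggyVersionHolder(0).version() == 0); // true

    // With a non-zero value, the same check exposes the bug.
    System.out.println(new BuggyVersionHolder(5).version() == 5); // false
  }
}
```

   The same reasoning applies to any value a test expects to see propagated: pick one that differs from the type's default.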



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/grpcclient/GrpcGetWorkerMetadataStream.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.grpcclient;
+
+import java.io.PrintWriter;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.apache.beam.runners.dataflow.worker.windmill.AbstractWindmillStream;
+import org.apache.beam.runners.dataflow.worker.windmill.StreamObserverFactory;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints;
+import org.apache.beam.runners.dataflow.worker.windmill.WindmillStream.GetWorkerMetadataStream;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+final class GrpcGetWorkerMetadataStream
+    extends AbstractWindmillStream<WorkerMetadataRequest, WorkerMetadataResponse>
+    implements GetWorkerMetadataStream {
+  private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkerMetadataStream.class);
+  private static final WorkerMetadataRequest HEALTH_CHECK_REQUEST =
+      WorkerMetadataRequest.getDefaultInstance();
+
+  private final AtomicLong metadataVersion;
+  private final WorkerMetadataRequest workerMetadataRequest;
+  private final ThrottleTimer getWorkerMetadataThrottleTimer;
+  private final Consumer<WindmillEndpoints> serverMappingUpdater;
+
+  private GrpcGetWorkerMetadataStream(
+      Function<StreamObserver<WorkerMetadataResponse>, StreamObserver<WorkerMetadataRequest>>
+          startGetWorkerMetadataRpcFn,
+      BackOff backoff,
+      StreamObserverFactory streamObserverFactory,
+      Set<AbstractWindmillStream<?, ?>> streamRegistry,
+      int logEveryNStreamFailures,
+      JobHeader jobHeader,
+      AtomicLong metadataVersion,
+      ThrottleTimer getWorkerMetadataThrottleTimer,
+      Consumer<WindmillEndpoints> serverMappingUpdater) {
+    super(
+        startGetWorkerMetadataRpcFn,
+        backoff,
+        streamObserverFactory,
+        streamRegistry,
+        logEveryNStreamFailures);
+    this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build();
+    this.metadataVersion = metadataVersion;
+    this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer;
+    this.serverMappingUpdater = serverMappingUpdater;
+  }
+
+  public static GrpcGetWorkerMetadataStream create(
+      Function<StreamObserver<WorkerMetadataResponse>, StreamObserver<WorkerMetadataRequest>>
+          startGetWorkerMetadataRpcFn,
+      BackOff backoff,
+      StreamObserverFactory streamObserverFactory,
+      Set<AbstractWindmillStream<?, ?>> streamRegistry,
+      int logEveryNStreamFailures,
+      JobHeader jobHeader,
+      AtomicLong metadataVersion,

Review Comment:
   Pass a plain integer here? The AtomicLong can stay internal in the non-testing case,
which ensures it isn't modified externally.
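
   Something along these lines, as a sketch only: callers pass a primitive initial version and the atomic counter stays private. `MetadataVersionTracker` and `tryAdvance` are hypothetical names, and treating an equal version as stale is an assumption not confirmed by this diff.

```java
import java.util.concurrent.atomic.AtomicLong;

final class InternalVersionSketch {

  /** Hypothetical tracker: the AtomicLong is created internally and never exposed. */
  static final class MetadataVersionTracker {
    private final AtomicLong metadataVersion;

    MetadataVersionTracker(long initialVersion) {
      this.metadataVersion = new AtomicLong(initialVersion);
    }

    /**
     * Advances to newVersion only if it is strictly newer than the current one.
     * Returns true when the version advanced, false when the update was stale.
     */
    boolean tryAdvance(long newVersion) {
      long current = metadataVersion.get();
      while (newVersion > current) {
        if (metadataVersion.compareAndSet(current, newVersion)) {
          return true;
        }
        current = metadataVersion.get(); // lost a race; re-read and retry
      }
      return false;
    }

    long currentVersion() {
      return metadataVersion.get();
    }
  }

  public static void main(String[] args) {
    MetadataVersionTracker tracker = new MetadataVersionTracker(1);
    System.out.println(tracker.tryAdvance(2)); // true
    System.out.println(tracker.tryAdvance(1)); // false, stale update ignored
    System.out.println(tracker.currentVersion()); // 2
  }
}
```

   A test-only factory could still accept an externally owned counter if the tests need to observe the version directly.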



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
