[REEF-245] Move Network test from REEF.Tests to REEF.Network.Tests project
This PR adds a new test project for the Network unit tests and moves the Network unit test cases to the new project. JIRA: [REEF-245](https://issues.apache.org/jira/browse/REEF-245) This closes #142 Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/987d4f37 Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/987d4f37 Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/987d4f37 Branch: refs/heads/master Commit: 987d4f37d6b7be9ef367655bd27a2c67097d2aec Parents: f44a659 Author: Julia Wang <[email protected]> Authored: Thu Apr 9 18:23:20 2015 -0700 Committer: Markus Weimer <[email protected]> Committed: Mon Apr 13 09:20:09 2015 -0700 ---------------------------------------------------------------------- .../BlockingCollectionExtensionTests.cs | 74 ++ .../GroupCommunicationTests.cs | 742 +++++++++++++++++++ .../GroupCommunicationTreeTopologyTests.cs | 634 ++++++++++++++++ .../NamingService/NameServerTests.cs | 269 +++++++ .../NetworkService/NetworkServiceTests.cs | 210 ++++++ .../Org.Apache.REEF.Network.Tests.csproj | 90 +++ .../Properties/AssemblyInfo.cs | 55 ++ .../packages.config | 23 + .../Network/BlockingCollectionExtensionTests.cs | 74 -- .../Network/GroupCommunicationTests.cs | 742 ------------------- .../GroupCommunicationTreeTopologyTests.cs | 634 ---------------- .../Network/NameServerTests.cs | 269 ------- .../Network/NetworkServiceTests.cs | 210 ------ .../Org.Apache.REEF.Tests.csproj | 5 - lang/cs/Org.Apache.REEF.sln | Bin 20908 -> 42590 bytes 15 files changed, 2097 insertions(+), 1934 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/987d4f37/lang/cs/Org.Apache.REEF.Network.Tests/BlockingCollectionExtensionTests.cs ---------------------------------------------------------------------- diff --git
a/lang/cs/Org.Apache.REEF.Network.Tests/BlockingCollectionExtensionTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/BlockingCollectionExtensionTests.cs new file mode 100644 index 0000000..aa4ebc6 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Network.Tests/BlockingCollectionExtensionTests.cs @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using System; +using System.Collections.Concurrent; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Org.Apache.REEF.Network.Utilities; + +namespace Org.Apache.REEF.Network.Tests +{ + [TestClass] + public class BlockingCollectionExtensionTests /* Tests for the Take(item) extension method on BlockingCollection<T>, brought in by the Org.Apache.REEF.Network.Utilities using directive. */ + { + [TestMethod] + public void TestCollectionContainsElement() /* Take(item) on a one-element collection returns that element and leaves the collection empty. */ + { + string item = "abc"; + BlockingCollection<string> collection = new BlockingCollection<string>(); + collection.Add(item); + + Assert.AreEqual(item, collection.Take(item)); + + // Check that item is no longer in collection + Assert.AreEqual(0, collection.Count); + } + + [TestMethod] + public void TestCollectionContainsElement2() /* Take(item) removes the matching element from between two others; the two remaining Take() calls must not return it. */ + { + string item = "abc"; + BlockingCollection<string> collection = new BlockingCollection<string>(); + collection.Add("cat"); + collection.Add(item); + collection.Add("dog"); + + Assert.AreEqual(item, collection.Take(item)); + + // Remove remaining items, check that item is not there + Assert.AreNotEqual(item, collection.Take()); + Assert.AreNotEqual(item, collection.Take()); + Assert.AreEqual(0, collection.Count); + } + + [TestMethod] + [ExpectedException(typeof(InvalidOperationException))] + public void TestCollectionDoesNotContainsElement() /* Take(item) for an element that is not in the collection is expected (via ExpectedException) to throw InvalidOperationException. */ + { + string item1 = "abc"; + string item2 = "def"; + + BlockingCollection<string> collection = new BlockingCollection<string>(); + collection.Add(item2); + + // Should throw InvalidOperationException since item1 is not in collection + collection.Take(item1); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/987d4f37/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs new file mode 100644 index 0000000..c0808d8 --- /dev/null +++
b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs @@ -0,0 +1,742 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Reactive; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Org.Apache.REEF.Common.Tasks; +using Org.Apache.REEF.Network.Group.Codec; +using Org.Apache.REEF.Network.Group.Config; +using Org.Apache.REEF.Network.Group.Driver; +using Org.Apache.REEF.Network.Group.Driver.Impl; +using Org.Apache.REEF.Network.Group.Operators; +using Org.Apache.REEF.Network.Group.Operators.Impl; +using Org.Apache.REEF.Network.Group.Task; +using Org.Apache.REEF.Network.Group.Topology; +using Org.Apache.REEF.Network.Naming; +using Org.Apache.REEF.Network.NetworkService; +using Org.Apache.REEF.Tang.Annotations; +using Org.Apache.REEF.Tang.Formats; +using Org.Apache.REEF.Tang.Implementations.Configuration; +using Org.Apache.REEF.Tang.Implementations.Tang; +using Org.Apache.REEF.Tang.Interface; +using Org.Apache.REEF.Tang.Util; +using Org.Apache.REEF.Wake.Remote; +using Org.Apache.REEF.Wake.Remote.Impl; + +namespace 
Org.Apache.REEF.Network.Tests.GroupCommunication
{
    /// <summary>
    /// Unit tests for group communication: the low-level <see cref="Sender"/>, the broadcast,
    /// reduce and scatter operators (using the group's default topology), and the task
    /// configurations produced by FlatTopology for broadcast/reduce operator specs.
    /// </summary>
    [TestClass]
    public class GroupCommunicationTests
    {
        /// <summary>
        /// Two network services registered with one name server exchange
        /// group communication messages through <see cref="Sender"/>.
        /// </summary>
        [TestMethod]
        public void TestSender()
        {
            using (NameServer nameServer = new NameServer(0))
            {
                IPEndPoint endpoint = nameServer.LocalEndpoint;
                BlockingCollection<GroupCommunicationMessage> messages1 = new BlockingCollection<GroupCommunicationMessage>();
                BlockingCollection<GroupCommunicationMessage> messages2 = new BlockingCollection<GroupCommunicationMessage>();

                var handler1 = Observer.Create<NsMessage<GroupCommunicationMessage>>(
                    msg => messages1.Add(msg.Data.First()));
                var handler2 = Observer.Create<NsMessage<GroupCommunicationMessage>>(
                    msg => messages2.Add(msg.Data.First()));

                // The network services are disposed (inner scope) before the name server
                // they registered with, so no connections or registrations are leaked.
                using (var networkService1 = BuildNetworkService(endpoint, handler1))
                using (var networkService2 = BuildNetworkService(endpoint, handler2))
                {
                    networkService1.Register(new StringIdentifier("id1"));
                    networkService2.Register(new StringIdentifier("id2"));

                    Sender sender1 = new Sender(networkService1, new StringIdentifierFactory());
                    Sender sender2 = new Sender(networkService2, new StringIdentifierFactory());

                    sender1.Send(CreateGcm("abc", "id1", "id2"));
                    sender1.Send(CreateGcm("def", "id1", "id2"));

                    sender2.Send(CreateGcm("ghi", "id2", "id1"));

                    string msg1 = Encoding.UTF8.GetString(messages2.Take().Data[0]);
                    string msg2 = Encoding.UTF8.GetString(messages2.Take().Data[0]);
                    Assert.AreEqual("abc", msg1);
                    Assert.AreEqual("def", msg2);

                    string msg3 = Encoding.UTF8.GetString(messages1.Take().Data[0]);
                    Assert.AreEqual("ghi", msg3);
                }
            }
        }

        /// <summary>
        /// The master broadcasts j = 1..10. Each of the other two tasks sends TriangleNumber(j)
        /// on the reduce operator, so the summed result must be TriangleNumber(j) * (numTasks - 1).
        /// </summary>
        [TestMethod]
        public void TestBroadcastReduceOperators()
        {
            string groupName = "group1";
            string broadcastOperatorName = "broadcast";
            string reduceOperatorName = "reduce";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 3;
            int fanOut = 2;

            var mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddBroadcast<int, IntCodec>(
                    broadcastOperatorName,
                    masterTaskId)
                .AddReduce<int, IntCodec>(
                    reduceOperatorName,
                    masterTaskId,
                    new SumFunction())
                .Build();

            var commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            // for master task
            IBroadcastSender<int> broadcastSender = commGroups[0].GetBroadcastSender<int>(broadcastOperatorName);
            IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

            IBroadcastReceiver<int> broadcastReceiver1 = commGroups[1].GetBroadcastReceiver<int>(broadcastOperatorName);
            IReduceSender<int> triangleNumberSender1 = commGroups[1].GetReduceSender<int>(reduceOperatorName);

            IBroadcastReceiver<int> broadcastReceiver2 = commGroups[2].GetBroadcastReceiver<int>(broadcastOperatorName);
            IReduceSender<int> triangleNumberSender2 = commGroups[2].GetReduceSender<int>(reduceOperatorName);

            for (int j = 1; j <= 10; j++)
            {
                broadcastSender.Send(j);

                int n1 = broadcastReceiver1.Receive();
                int n2 = broadcastReceiver2.Receive();
                Assert.AreEqual(j, n1);
                Assert.AreEqual(j, n2);

                int triangleNum1 = TriangleNumber(n1);
                triangleNumberSender1.Send(triangleNum1);
                int triangleNum2 = TriangleNumber(n2);
                triangleNumberSender2.Send(triangleNum2);

                int sum = sumReducer.Reduce();
                int expected = TriangleNumber(j) * (numTasks - 1);
                Assert.AreEqual(expected, sum);
            }
        }

        /// <summary>
        /// The master scatters 1..100 to four tasks in an explicit order. Each task sends the sum
        /// of its chunk back through reduce, and the reduced total must equal the sum of the data.
        /// </summary>
        [TestMethod]
        public void TestScatterReduceOperators()
        {
            string groupName = "group1";
            string scatterOperatorName = "scatter";
            string reduceOperatorName = "reduce";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(
                    scatterOperatorName,
                    masterTaskId)
                .AddReduce<int, IntCodec>(
                    reduceOperatorName,
                    masterTaskId,
                    new SumFunction())
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(scatterOperatorName);
            IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender1 = commGroups[1].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender2 = commGroups[2].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender3 = commGroups[3].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender4 = commGroups[4].GetReduceSender<int>(reduceOperatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = Enumerable.Range(1, 100).ToList();
            List<string> order = new List<string> { "task4", "task3", "task2", "task1" };

            sender.Send(data, order);

            ScatterReceiveReduce(receiver4, sumSender4);
            ScatterReceiveReduce(receiver3, sumSender3);
            ScatterReceiveReduce(receiver2, sumSender2);
            ScatterReceiveReduce(receiver1, sumSender1);

            int sum = sumReducer.Reduce();

            Assert.AreEqual(data.Sum(), sum);
        }

        /// <summary>
        /// A single broadcast with an explicit <see cref="IntCodec"/> reaches the receivers.
        /// Ten tasks are configured; two receivers are checked.
        /// </summary>
        [TestMethod]
        public void TestBroadcastOperator()
        {
            string groupName = "group1";
            string operatorName = "broadcast";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 10;
            int value = 1337;
            int fanOut = 3;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddBroadcast<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName);
            IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName);
            IBroadcastReceiver<int> receiver2 = commGroups[2].GetBroadcastReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);

            sender.Send(value);
            Assert.AreEqual(value, receiver1.Receive());
            Assert.AreEqual(value, receiver2.Receive());
        }

        /// <summary>
        /// Same as <see cref="TestBroadcastOperator"/>, but the broadcast is added without an
        /// explicit codec type argument.
        /// </summary>
        [TestMethod]
        public void TestBroadcastOperatorWithDefaultCodec()
        {
            string groupName = "group1";
            string operatorName = "broadcast";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 10;
            int value = 1337;
            int fanOut = 3;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddBroadcast(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName);
            IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName);
            IBroadcastReceiver<int> receiver2 = commGroups[2].GetBroadcastReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);

            sender.Send(value);
            Assert.AreEqual(value, receiver1.Receive());
            Assert.AreEqual(value, receiver2.Receive());
        }

        /// <summary>
        /// Three successive broadcasts are each received in order by both receivers.
        /// </summary>
        [TestMethod]
        public void TestBroadcastOperator2()
        {
            string groupName = "group1";
            string operatorName = "broadcast";
            string driverId = "driverId";
            string masterTaskId = "task0";
            int numTasks = 3;
            int value1 = 1337;
            int value2 = 42;
            int value3 = 99;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddBroadcast<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName);
            IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName);
            IBroadcastReceiver<int> receiver2 = commGroups[2].GetBroadcastReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);

            sender.Send(value1);
            Assert.AreEqual(value1, receiver1.Receive());
            Assert.AreEqual(value1, receiver2.Receive());

            sender.Send(value2);
            Assert.AreEqual(value2, receiver1.Receive());
            Assert.AreEqual(value2, receiver2.Receive());

            sender.Send(value3);
            Assert.AreEqual(value3, receiver1.Receive());
            Assert.AreEqual(value3, receiver2.Receive());
        }

        /// <summary>
        /// Three senders contribute 5, 1 and 3 (sent out of task order); the master reduces them to 9.
        /// </summary>
        [TestMethod]
        public void TestReduceOperator()
        {
            string groupName = "group1";
            string operatorName = "reduce";
            int numTasks = 4;
            string driverId = "driverid";
            string masterTaskId = "task0";
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddReduce<int, IntCodec>(operatorName, masterTaskId, new SumFunction())
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IReduceReceiver<int> receiver = commGroups[0].GetReduceReceiver<int>(operatorName);
            IReduceSender<int> sender1 = commGroups[1].GetReduceSender<int>(operatorName);
            IReduceSender<int> sender2 = commGroups[2].GetReduceSender<int>(operatorName);
            IReduceSender<int> sender3 = commGroups[3].GetReduceSender<int>(operatorName);

            Assert.IsNotNull(receiver);
            Assert.IsNotNull(sender1);
            Assert.IsNotNull(sender2);
            Assert.IsNotNull(sender3);

            sender3.Send(5);
            sender1.Send(1);
            sender2.Send(3);

            Assert.AreEqual(9, receiver.Reduce());
        }

        /// <summary>
        /// Three consecutive reduce rounds produce the per-round sums 9, 12 and 18.
        /// </summary>
        [TestMethod]
        public void TestReduceOperator2()
        {
            string groupName = "group1";
            string operatorName = "reduce";
            int numTasks = 4;
            string driverId = "driverid";
            string masterTaskId = "task0";
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddReduce<int, IntCodec>(operatorName, masterTaskId, new SumFunction())
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IReduceReceiver<int> receiver = commGroups[0].GetReduceReceiver<int>(operatorName);
            IReduceSender<int> sender1 = commGroups[1].GetReduceSender<int>(operatorName);
            IReduceSender<int> sender2 = commGroups[2].GetReduceSender<int>(operatorName);
            IReduceSender<int> sender3 = commGroups[3].GetReduceSender<int>(operatorName);

            Assert.IsNotNull(receiver);
            Assert.IsNotNull(sender1);
            Assert.IsNotNull(sender2);
            Assert.IsNotNull(sender3);

            sender3.Send(5);
            sender1.Send(1);
            sender2.Send(3);
            Assert.AreEqual(9, receiver.Reduce());

            sender3.Send(6);
            sender1.Send(2);
            sender2.Send(4);
            Assert.AreEqual(12, receiver.Reduce());

            sender3.Send(9);
            sender1.Send(3);
            sender2.Send(6);
            Assert.AreEqual(18, receiver.Reduce());
        }

        /// <summary>
        /// Scattering four elements with an explicit <see cref="IntCodec"/> over four receivers
        /// delivers exactly one element to each, in task order.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            // Explicit codec; TestScatterOperatorWithDefaultCodec covers the codec-less overload.
            var commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = new List<int> { 1, 2, 3, 4 };

            sender.Send(data);
            Assert.AreEqual(1, receiver1.Receive().Single());
            Assert.AreEqual(2, receiver2.Receive().Single());
            Assert.AreEqual(3, receiver3.Receive().Single());
            Assert.AreEqual(4, receiver4.Receive().Single());
        }

        /// <summary>
        /// Same as <see cref="TestScatterOperator"/>, but the scatter is added without an
        /// explicit codec type argument.
        /// </summary>
        [TestMethod]
        public void TestScatterOperatorWithDefaultCodec()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddScatter(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = new List<int> { 1, 2, 3, 4 };

            sender.Send(data);
            Assert.AreEqual(1, receiver1.Receive().Single());
            Assert.AreEqual(2, receiver2.Receive().Single());
            Assert.AreEqual(3, receiver3.Receive().Single());
            Assert.AreEqual(4, receiver4.Receive().Single());
        }

        /// <summary>
        /// Eight elements scattered over four receivers: each receiver gets two consecutive elements.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator2()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

            sender.Send(data);
            var data1 = receiver1.Receive();
            Assert.AreEqual(1, data1.First());
            Assert.AreEqual(2, data1.Last());

            var data2 = receiver2.Receive();
            Assert.AreEqual(3, data2.First());
            Assert.AreEqual(4, data2.Last());

            var data3 = receiver3.Receive();
            Assert.AreEqual(5, data3.First());
            Assert.AreEqual(6, data3.Last());

            var data4 = receiver4.Receive();
            Assert.AreEqual(7, data4.First());
            Assert.AreEqual(8, data4.Last());
        }

        /// <summary>
        /// Eight elements scattered over three receivers: the assertions expect chunks of 3, 3 and 2.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator3()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 4;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

            sender.Send(data);

            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(1, data1[0]);
            Assert.AreEqual(2, data1[1]);
            Assert.AreEqual(3, data1[2]);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(4, data2[0]);
            Assert.AreEqual(5, data2[1]);
            Assert.AreEqual(6, data2[2]);

            var data3 = receiver3.Receive().ToArray();
            Assert.AreEqual(7, data3[0]);
            Assert.AreEqual(8, data3[1]);
        }

        /// <summary>
        /// Like <see cref="TestScatterOperator3"/>, but with an explicit order (task3, task2, task1):
        /// the first chunk goes to task3 and the last chunk to task1.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator4()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 4;
            int fanOut = 2;

            IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            var commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId)
                .Build();

            List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup);
            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };
            List<string> order = new List<string> { "task3", "task2", "task1" };

            sender.Send(data, order);

            var data3 = receiver3.Receive().ToArray();
            Assert.AreEqual(1, data3[0]);
            Assert.AreEqual(2, data3[1]);
            Assert.AreEqual(3, data3[2]);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(4, data2[0]);
            Assert.AreEqual(5, data2[1]);
            Assert.AreEqual(6, data2[2]);

            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(7, data1[0]);
            Assert.AreEqual(8, data1[1]);
        }

        /// <summary>
        /// The task configuration generated for a broadcast spec binds an <see cref="ICodec{T}"/>
        /// that round-trips an int.
        /// </summary>
        [TestMethod]
        public void TestConfigurationBroadcastSpec()
        {
            FlatTopology<int, IntCodec> topology = new FlatTopology<int, IntCodec>("Operator", "Operator", "task1", "driverid",
                new BroadcastOperatorSpec<int, IntCodec>("Sender"));

            topology.AddTask("task1");
            var conf = topology.GetTaskConfiguration("task1");

            ICodec<int> codec = TangFactory.GetTang().NewInjector(conf).GetInstance<ICodec<int>>();
            Assert.AreEqual(3, codec.Decode(codec.Encode(3)));
        }

        /// <summary>
        /// The task configuration generated for a reduce spec binds the <see cref="SumFunction"/>
        /// as the task's <see cref="IReduceFunction{T}"/>.
        /// </summary>
        [TestMethod]
        public void TestConfigurationReduceSpec()
        {
            FlatTopology<int, IntCodec> topology = new FlatTopology<int, IntCodec>("Operator", "Group", "task1", "driverid",
                new ReduceOperatorSpec<int, IntCodec>("task1", new SumFunction()));

            topology.AddTask("task1");
            var conf2 = topology.GetTaskConfiguration("task1");

            IReduceFunction<int> reduceFunction = TangFactory.GetTang().NewInjector(conf2).GetInstance<IReduceFunction<int>>();
            Assert.AreEqual(10, reduceFunction.Reduce(new int[] { 1, 2, 3, 4 }));
        }

        /// <summary>
        /// Builds an <see cref="MpiDriver"/> through Tang from the given named parameters,
        /// with <see cref="AvroConfigurationSerializer"/> bound as the configuration serializer.
        /// </summary>
        /// <param name="driverId">Value for MpiConfigurationOptions.DriverId.</param>
        /// <param name="masterTaskId">Value for MpiConfigurationOptions.MasterTaskId.</param>
        /// <param name="groupName">Value for MpiConfigurationOptions.GroupName.</param>
        /// <param name="fanOut">Value for MpiConfigurationOptions.FanOut.</param>
        /// <param name="numTasks">Value for MpiConfigurationOptions.NumberOfTasks.</param>
        /// <returns>The injected driver.</returns>
        public static IMpiDriver GetInstanceOfMpiDriver(string driverId, string masterTaskId, string groupName, int fanOut, int numTasks)
        {
            var c = TangFactory.GetTang().NewConfigurationBuilder()
                .BindStringNamedParam<MpiConfigurationOptions.DriverId>(driverId)
                .BindStringNamedParam<MpiConfigurationOptions.MasterTaskId>(masterTaskId)
                .BindStringNamedParam<MpiConfigurationOptions.GroupName>(groupName)
                .BindIntNamedParam<MpiConfigurationOptions.FanOut>(fanOut.ToString())
                .BindIntNamedParam<MpiConfigurationOptions.NumberOfTasks>(numTasks.ToString())
                .BindImplementation(GenericType<IConfigurationSerializer>.Class, GenericType<AvroConfigurationSerializer>.Class)
                .Build();

            IMpiDriver mpiDriver = TangFactory.GetTang().NewInjector(c).GetInstance<MpiDriver>();
            return mpiDriver;
        }

        /// <summary>
        /// Adds tasks "task0" .. "task{numTasks - 1}" to <paramref name="commGroup"/>, then
        /// instantiates an <see cref="IMpiClient"/> for each task and returns its communication group.
        /// </summary>
        /// <param name="groupName">Name of the communication group to fetch from each client.</param>
        /// <param name="numTasks">Number of tasks to create.</param>
        /// <param name="mpiDriver">Driver providing the service and per-task MPI configurations.</param>
        /// <param name="commGroup">Group driver to which the tasks are added.</param>
        /// <returns>One communication group client per task, indexed by task number.</returns>
        public static List<ICommunicationGroupClient> CommGroupClients(string groupName, int numTasks, IMpiDriver mpiDriver, ICommunicationGroupDriver commGroup)
        {
            List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>();
            IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration();

            // Add every task to the group before asking for any per-task MPI configuration
            // (presumably the topology must be complete first — not verified here).
            List<IConfiguration> partialConfigs = new List<IConfiguration>();
            for (int i = 0; i < numTasks; i++)
            {
                string taskId = "task" + i;
                IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder(
                    TaskConfiguration.ConfigurationModule
                        .Set(TaskConfiguration.Identifier, taskId)
                        .Set(TaskConfiguration.Task, GenericType<MyTask>.Class)
                        .Build())
                    .Build();
                commGroup.AddTask(taskId);
                partialConfigs.Add(partialTaskConfig);
            }

            for (int i = 0; i < numTasks; i++)
            {
                string taskId = "task" + i;
                IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId);
                IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig);
                IInjector injector = TangFactory.GetTang().NewInjector(mergedConf);

                IMpiClient mpiClient = injector.GetInstance<IMpiClient>();
                commGroups.Add(mpiClient.GetCommunicationGroup(groupName));
            }
            return commGroups;
        }

        /// <summary>
        /// Creates a <see cref="NetworkService{T}"/> for group communication messages whose
        /// <see cref="NameClient"/> points at the given name server.
        /// The first constructor argument is 0 — presumably "pick any free port"; confirm.
        /// </summary>
        /// <param name="nameServerEndpoint">Address and port of the name server.</param>
        /// <param name="handler">Observer invoked for incoming messages.</param>
        /// <returns>The new network service; the caller owns it and should dispose it.</returns>
        public static NetworkService<GroupCommunicationMessage> BuildNetworkService(
            IPEndPoint nameServerEndpoint, IObserver<NsMessage<GroupCommunicationMessage>> handler)
        {
            return new NetworkService<GroupCommunicationMessage>(
                0, handler, new StringIdentifierFactory(), new GroupCommunicationMessageCodec(), new NameClient(nameServerEndpoint.Address.ToString(), nameServerEndpoint.Port));
        }

        /// <summary>
        /// Builds a Data message for group "g1" / operator "op1" carrying the UTF-8 bytes of
        /// <paramref name="message"/>.
        /// </summary>
        private static GroupCommunicationMessage CreateGcm(string message, string from, string to)
        {
            byte[] data = Encoding.UTF8.GetBytes(message);
            return new GroupCommunicationMessage("g1", "op1", from, to, data, MessageType.Data);
        }

        /// <summary>
        /// Receives this task's scatter chunk and sends the sum of its elements on the reduce sender.
        /// </summary>
        private static void ScatterReceiveReduce(IScatterReceiver<int> receiver, IReduceSender<int> sumSender)
        {
            List<int> data1 = receiver.Receive();
            int sum1 = data1.Sum();
            sumSender.Send(sum1);
        }

        /// <summary>
        /// Returns 1 + 2 + ... + n (0 when n is 0). A negative n makes
        /// <see cref="Enumerable.Range"/> throw <see cref="ArgumentOutOfRangeException"/>.
        /// </summary>
        public static int TriangleNumber(int n)
        {
            return Enumerable.Range(1, n).Sum();
        }
    }

    /// <summary>
    /// Reduce function that sums integers; the [Inject] constructor lets Tang instantiate it.
    /// </summary>
    public class SumFunction : IReduceFunction<int>
    {
        [Inject]
        public SumFunction()
        {
        }

        public int Reduce(IEnumerable<int> elements)
        {
            return elements.Sum();
        }
    }

    /// <summary>
    /// Placeholder task bound in the test task configurations. It is never executed by these tests.
    /// </summary>
    public class MyTask : ITask
    {
        /// <summary>
        /// No resources to release. Dispose must not throw, so this is a no-op.
        /// </summary>
        public void Dispose()
        {
        }

        /// <summary>
        /// Not implemented; the tests only need the type for configuration.
        /// </summary>
        public byte[] Call(byte[] memento)
        {
            throw new NotImplementedException();
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/987d4f37/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTreeTopologyTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTreeTopologyTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTreeTopologyTests.cs new file mode 100644 index 0000000..b244f33 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTreeTopologyTests.cs @@ -0,0 +1,634 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */

using System.Collections.Generic;
using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Org.Apache.REEF.Network.Group.Driver;
using Org.Apache.REEF.Network.Group.Operators;
using Org.Apache.REEF.Network.Group.Operators.Impl;
using Org.Apache.REEF.Network.Group.Topology;
using Org.Apache.REEF.Wake.Remote.Impl;

namespace Org.Apache.REEF.Network.Tests.GroupCommunication
{
    /// <summary>
    /// Tests for the broadcast, reduce and scatter operators when they are configured with a
    /// <see cref="TopologyTypes.Tree"/> topology.
    /// </summary>
    /// <remarks>
    /// Every test drives all tasks' operators sequentially from the single test thread, so the
    /// order of the Send/Receive calls matters. NOTE(review): the orderings used below (reduce
    /// values sent from the highest task id down, broadcast values received from the lowest task
    /// id up) presumably make each task's tree children send before it and let it receive before
    /// its children; confirm before reordering any of these calls.
    /// </remarks>
    [TestClass]
    public class GroupCommunicationTreeTopologyTests
    {
        /// <summary>
        /// Builds a fan-out-2 tree topology of seven tasks and checks that a task configuration is
        /// produced for every task that was added.
        /// </summary>
        [TestMethod]
        public void TestTreeTopology()
        {
            TreeTopology<int, IntCodec> topology = new TreeTopology<int, IntCodec>("Operator", "Operator", "task1", "driverid",
                new BroadcastOperatorSpec<int, IntCodec>("task1"), 2);
            for (int i = 1; i < 8; i++)
            {
                string taskid = "task" + i;
                topology.AddTask(taskid);
            }

            for (int i = 1; i < 8; i++)
            {
                // Assert on the result so the test checks more than "does not throw".
                var conf = topology.GetTaskConfiguration("task" + i);
                Assert.IsNotNull(conf);
            }
        }

        /// <summary>
        /// Tasks 1..9 each send their own index through a summing tree reduce; the master must
        /// receive 1 + 2 + ... + 9.
        /// </summary>
        [TestMethod]
        public void TestReduceOperator()
        {
            string groupName = "group1";
            string operatorName = "reduce";
            int numTasks = 10;
            string driverId = "driverId";
            string masterTaskId = "task0";
            int fanOut = 3;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddReduce<int, IntCodec>(operatorName, masterTaskId, new SumFunction(), TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IReduceReceiver<int> receiver = commGroups[0].GetReduceReceiver<int>(operatorName);
            Assert.IsNotNull(receiver);

            // senders[i] belongs to task i; slot 0 (the master) is unused.
            var senders = new IReduceSender<int>[numTasks];
            for (int i = 1; i < numTasks; i++)
            {
                senders[i] = commGroups[i].GetReduceSender<int>(operatorName);
                Assert.IsNotNull(senders[i]);
            }

            // Highest task id first; see the class remarks on ordering.
            for (int i = numTasks - 1; i >= 1; i--)
            {
                senders[i].Send(i);
            }

            // 1 + 2 + ... + 9
            Assert.AreEqual(45, receiver.Reduce());
        }

        /// <summary>
        /// The master broadcasts three values in turn; every other task must receive each of them.
        /// </summary>
        [TestMethod]
        public void TestBroadcastOperator()
        {
            string groupName = "group1";
            string operatorName = "broadcast";
            string driverId = "driverId";
            string masterTaskId = "task0";
            int numTasks = 10;
            int value1 = 1337;
            int value2 = 42;
            int value3 = 99;
            int fanOut = 3;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddBroadcast<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName);
            Assert.IsNotNull(sender);

            // receivers[i] belongs to task i; slot 0 (the master) is unused.
            var receivers = new IBroadcastReceiver<int>[numTasks];
            for (int i = 1; i < numTasks; i++)
            {
                receivers[i] = commGroups[i].GetBroadcastReceiver<int>(operatorName);
                Assert.IsNotNull(receivers[i]);
            }

            foreach (int value in new[] { value1, value2, value3 })
            {
                sender.Send(value);

                // Lowest task id first; see the class remarks on ordering.
                for (int i = 1; i < numTasks; i++)
                {
                    Assert.AreEqual(value, receivers[i].Receive());
                }
            }
        }

        /// <summary>
        /// For n = 1..10 the master broadcasts n, every task replies with TriangleNumber(n) through a
        /// summing reduce, and the master must receive TriangleNumber(n) * (numTasks - 1).
        /// </summary>
        [TestMethod]
        public void TestBroadcastReduceOperators()
        {
            string groupName = "group1";
            string broadcastOperatorName = "broadcast";
            string reduceOperatorName = "reduce";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 10;
            int fanOut = 3;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddBroadcast<int, IntCodec>(
                    broadcastOperatorName,
                    masterTaskId,
                    TopologyTypes.Tree)
                .AddReduce<int, IntCodec>(
                    reduceOperatorName,
                    masterTaskId,
                    new SumFunction(),
                    TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            // Master task: broadcasts n and reduces the other tasks' replies.
            IBroadcastSender<int> broadcastSender = commGroups[0].GetBroadcastSender<int>(broadcastOperatorName);
            IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

            // Slot t belongs to task t; slot 0 (the master) is unused.
            var broadcastReceivers = new IBroadcastReceiver<int>[numTasks];
            var triangleNumberSenders = new IReduceSender<int>[numTasks];
            for (int t = 1; t < numTasks; t++)
            {
                broadcastReceivers[t] = commGroups[t].GetBroadcastReceiver<int>(broadcastOperatorName);
                triangleNumberSenders[t] = commGroups[t].GetReduceSender<int>(reduceOperatorName);
            }

            for (int i = 1; i <= 10; i++)
            {
                broadcastSender.Send(i);

                // Lowest task id first; see the class remarks on ordering.
                var received = new int[numTasks];
                for (int t = 1; t < numTasks; t++)
                {
                    received[t] = broadcastReceivers[t].Receive();
                }

                for (int t = 1; t < numTasks; t++)
                {
                    Assert.AreEqual(i, received[t]);
                }

                // Each task replies with the triangle number of the value it received,
                // highest task id first; see the class remarks on ordering.
                for (int t = numTasks - 1; t >= 1; t--)
                {
                    triangleNumberSenders[t].Send(GroupCommunicationTests.TriangleNumber(received[t]));
                }

                int sum = sumReducer.Reduce();
                int expected = GroupCommunicationTests.TriangleNumber(i) * (numTasks - 1);
                Assert.AreEqual(expected, sum);
            }
        }

        /// <summary>
        /// Scatters four values over five tasks with fan-out 2. The expected values show that task1
        /// and task2 each receive half of the data and that task1's half is split again between
        /// task3 and task4.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = new List<int> { 1, 2, 3, 4 };

            sender.Send(data);
            var received1 = receiver1.Receive().ToArray();
            Assert.AreEqual(1, received1[0]);
            Assert.AreEqual(2, received1[1]);

            var received2 = receiver2.Receive().ToArray();
            Assert.AreEqual(3, received2[0]);
            Assert.AreEqual(4, received2[1]);

            Assert.AreEqual(1, receiver3.Receive().Single());
            Assert.AreEqual(2, receiver4.Receive().Single());
        }

        /// <summary>
        /// Scatters eight values over five tasks with fan-out 2: task1 and task2 get four values
        /// each, and task1's four are split two-and-two between task3 and task4.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator2()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

            sender.Send(data);
            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(1, data1[0]);
            Assert.AreEqual(2, data1[1]);
            Assert.AreEqual(3, data1[2]);
            Assert.AreEqual(4, data1[3]);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(5, data2[0]);
            Assert.AreEqual(6, data2[1]);
            Assert.AreEqual(7, data2[2]);
            Assert.AreEqual(8, data2[3]);

            var data3 = receiver3.Receive();
            Assert.AreEqual(1, data3.First());
            Assert.AreEqual(2, data3.Last());

            var data4 = receiver4.Receive();
            Assert.AreEqual(3, data4.First());
            Assert.AreEqual(4, data4.Last());
        }

        /// <summary>
        /// Scatters eight values over four tasks with fan-out 2: task3 is task1's only child and so
        /// receives all of task1's share.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator3()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 4;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

            sender.Send(data);

            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(1, data1[0]);
            Assert.AreEqual(2, data1[1]);
            Assert.AreEqual(3, data1[2]);
            Assert.AreEqual(4, data1[3]);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(5, data2[0]);
            Assert.AreEqual(6, data2[1]);
            Assert.AreEqual(7, data2[2]);
            Assert.AreEqual(8, data2[3]);

            var data3 = receiver3.Receive().ToArray();
            Assert.AreEqual(1, data3[0]);
            Assert.AreEqual(2, data3[1]);
            Assert.AreEqual(3, data3[2]);
            Assert.AreEqual(4, data3[3]);
        }

        /// <summary>
        /// Same layout as <see cref="TestScatterOperator3"/>, but with an explicit send order that
        /// gives task2 the first share; task3 then receives task1's (second) share.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator4()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 4;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };
            List<string> order = new List<string> { "task2", "task1" };

            sender.Send(data, order);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(1, data2[0]);
            Assert.AreEqual(2, data2[1]);
            Assert.AreEqual(3, data2[2]);
            Assert.AreEqual(4, data2[3]);

            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(5, data1[0]);
            Assert.AreEqual(6, data1[1]);
            Assert.AreEqual(7, data1[2]);
            Assert.AreEqual(8, data1[3]);

            var data3 = receiver3.Receive().ToArray();
            Assert.AreEqual(5, data3[0]);
            Assert.AreEqual(6, data3[1]);
            Assert.AreEqual(7, data3[2]);
            Assert.AreEqual(8, data3[3]);
        }

        /// <summary>
        /// Scatters eight values over six tasks with fan-out 2: task1's share is split between
        /// task3 and task4, and task2's share goes entirely to task5.
        /// </summary>
        [TestMethod]
        public void TestScatterOperator5()
        {
            string groupName = "group1";
            string operatorName = "scatter";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 6;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(operatorName, masterTaskId, TopologyTypes.Tree)
                .Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName);
            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName);
            IScatterReceiver<int> receiver5 = commGroups[5].GetScatterReceiver<int>(operatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);
            Assert.IsNotNull(receiver5);

            List<int> data = new List<int> { 1, 2, 3, 4, 5, 6, 7, 8 };

            sender.Send(data);

            var data1 = receiver1.Receive().ToArray();
            Assert.AreEqual(1, data1[0]);
            Assert.AreEqual(2, data1[1]);
            Assert.AreEqual(3, data1[2]);
            Assert.AreEqual(4, data1[3]);

            var data2 = receiver2.Receive().ToArray();
            Assert.AreEqual(5, data2[0]);
            Assert.AreEqual(6, data2[1]);
            Assert.AreEqual(7, data2[2]);
            Assert.AreEqual(8, data2[3]);

            var data3 = receiver3.Receive().ToArray();
            Assert.AreEqual(1, data3[0]);
            Assert.AreEqual(2, data3[1]);

            var data4 = receiver4.Receive().ToArray();
            Assert.AreEqual(3, data4[0]);
            Assert.AreEqual(4, data4[1]);

            var data5 = receiver5.Receive().ToArray();
            Assert.AreEqual(5, data5[0]);
            Assert.AreEqual(6, data5[1]);
            Assert.AreEqual(7, data5[2]);
            Assert.AreEqual(8, data5[3]);
        }

        /// <summary>
        /// Scatters 1..100 over five tasks, has every task send the sum of what it received through
        /// a summing tree reduce, and checks the total seen by the master.
        /// </summary>
        [TestMethod]
        public void TestScatterReduceOperators()
        {
            string groupName = "group1";
            string scatterOperatorName = "scatter";
            string reduceOperatorName = "reduce";
            string masterTaskId = "task0";
            string driverId = "Driver Id";
            int numTasks = 5;
            int fanOut = 2;

            var mpiDriver = GroupCommunicationTests.GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks);

            ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup
                .AddScatter<int, IntCodec>(
                    scatterOperatorName,
                    masterTaskId,
                    TopologyTypes.Tree)
                .AddReduce<int, IntCodec>(
                    reduceOperatorName,
                    masterTaskId,
                    new SumFunction(),
                    TopologyTypes.Tree).Build();

            var commGroups = GroupCommunicationTests.CommGroupClients(groupName, numTasks, mpiDriver, commGroup);

            IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(scatterOperatorName);
            IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName);

            IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender1 = commGroups[1].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender2 = commGroups[2].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender3 = commGroups[3].GetReduceSender<int>(reduceOperatorName);

            IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(scatterOperatorName);
            IReduceSender<int> sumSender4 = commGroups[4].GetReduceSender<int>(reduceOperatorName);

            Assert.IsNotNull(sender);
            Assert.IsNotNull(receiver1);
            Assert.IsNotNull(receiver2);
            Assert.IsNotNull(receiver3);
            Assert.IsNotNull(receiver4);

            List<int> data = Enumerable.Range(1, 100).ToList();

            sender.Send(data);

            List<int> data1 = receiver1.Receive();
            List<int> data2 = receiver2.Receive();

            List<int> data3 = receiver3.Receive();
            List<int> data4 = receiver4.Receive();

            // Children before parents; see the class remarks on ordering.
            int sum3 = data3.Sum();
            sumSender3.Send(sum3);

            int sum4 = data4.Sum();
            sumSender4.Send(sum4);

            int sum2 = data2.Sum();
            sumSender2.Send(sum2);

            int sum1 = data1.Sum();
            sumSender1.Send(sum1);

            int sum = sumReducer.Reduce();

            // Following the split pinned by TestScatterOperator, task1 receives the share for its
            // whole subtree (1..50), which is then split again between task3 (1..25) and task4
            // (26..50), so those values are counted twice:
            // 1275 (task1) + 3775 (task2: 51..100) + 325 (task3) + 950 (task4) = 6325.
            Assert.AreEqual(6325, sum);
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/987d4f37/lang/cs/Org.Apache.REEF.Network.Tests/NamingService/NameServerTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/NamingService/NameServerTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/NamingService/NameServerTests.cs new file mode 100644 index 0000000..fd3002c --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Network.Tests/NamingService/NameServerTests.cs @@ -0,0 +1,269 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */

using System;
using System.Collections.Generic;
using System.Net;
using System.Threading;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Org.Apache.REEF.Common.Io;
using Org.Apache.REEF.Network.Naming;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Tang.Implementations.Tang;
using Org.Apache.REEF.Tang.Interface;
using Org.Apache.REEF.Tang.Util;

namespace Org.Apache.REEF.Network.Tests.NamingService
{
    /// <summary>
    /// Tests for <see cref="NameServer"/> and <see cref="NameClient"/>: registration, lookup,
    /// unregistration, client restart and Tang constructor injection.
    /// </summary>
    /// <remarks>
    /// All servers are created on port 0 so the OS assigns a free port; the actual port is read
    /// back from <c>LocalEndpoint</c>. This avoids clashes with other tests or processes that a
    /// hard-coded port could run into. Clients and servers are always disposed, even when an
    /// assertion fails.
    /// </remarks>
    [TestClass]
    public class NameServerTests
    {
        /// <summary>Starting and disposing a server without any traffic must not throw.</summary>
        [TestMethod]
        public void TestNameServerNoRequests()
        {
            using (var server = new NameServer(0))
            {
            }
        }

        /// <summary>An endpoint registered through one client is visible through another.</summary>
        [TestMethod]
        public void TestNameServerNoRequestsTwoClients()
        {
            using (var server = new NameServer(0))
            using (var nameClient = new NameClient(server.LocalEndpoint))
            using (var nameClient2 = new NameClient(server.LocalEndpoint))
            {
                IPEndPoint endpoint = new IPEndPoint(IPAddress.Any, 8080);
                nameClient2.Register("1", endpoint);
                Assert.AreEqual(endpoint, nameClient.Lookup("1"));
            }
        }

        /// <summary>
        /// An endpoint registered through one client is visible through another.
        /// NOTE(review): identical to <see cref="TestNameServerNoRequestsTwoClients"/>; consider
        /// removing or giving it a distinct scenario.
        /// </summary>
        [TestMethod]
        public void TestNameServerNoRequestsTwoClients2()
        {
            using (var server = new NameServer(0))
            using (var nameClient = new NameClient(server.LocalEndpoint))
            using (var nameClient2 = new NameClient(server.LocalEndpoint))
            {
                IPEndPoint endpoint = new IPEndPoint(IPAddress.Any, 8080);
                nameClient2.Register("1", endpoint);
                Assert.AreEqual(endpoint, nameClient.Lookup("1"));
            }
        }

        /// <summary>
        /// Same as the two-client test with the roles swapped: the first client registers and the
        /// second looks up.
        /// </summary>
        [TestMethod]
        public void TestNameServerMultipleRequestsTwoClients()
        {
            using (var server = new NameServer(0))
            using (var nameClient = new NameClient(server.LocalEndpoint))
            using (var nameClient2 = new NameClient(server.LocalEndpoint))
            {
                IPEndPoint endpoint = new IPEndPoint(IPAddress.Any, 8080);
                nameClient.Register("1", endpoint);
                Assert.AreEqual(endpoint, nameClient2.Lookup("1"));
            }
        }

        /// <summary>Unregistered ids look up as null; registered ids return their endpoint.</summary>
        [TestMethod]
        public void TestRegister()
        {
            using (INameServer server = BuildNameServer())
            {
                using (INameClient client = BuildNameClient(server.LocalEndpoint))
                {
                    IPEndPoint endpoint1 = new IPEndPoint(IPAddress.Parse("100.0.0.1"), 100);
                    IPEndPoint endpoint2 = new IPEndPoint(IPAddress.Parse("100.0.0.2"), 200);
                    IPEndPoint endpoint3 = new IPEndPoint(IPAddress.Parse("100.0.0.3"), 300);

                    // Check that no endpoints have been registered
                    Assert.IsNull(client.Lookup("a"));
                    Assert.IsNull(client.Lookup("b"));
                    Assert.IsNull(client.Lookup("c"));

                    // Register endpoints
                    client.Register("a", endpoint1);
                    client.Register("b", endpoint2);
                    client.Register("c", endpoint3);

                    // Check that they can be looked up correctly
                    Assert.AreEqual(endpoint1, client.Lookup("a"));
                    Assert.AreEqual(endpoint2, client.Lookup("b"));
                    Assert.AreEqual(endpoint3, client.Lookup("c"));
                }
            }
        }

        /// <summary>After unregistering, an id looks up as null again.</summary>
        [TestMethod]
        public void TestUnregister()
        {
            using (INameServer server = BuildNameServer())
            {
                using (INameClient client = BuildNameClient(server.LocalEndpoint))
                {
                    IPEndPoint endpoint1 = new IPEndPoint(IPAddress.Parse("100.0.0.1"), 100);

                    // Register endpoint
                    client.Register("a", endpoint1);

                    // Check that it can be looked up correctly
                    Assert.AreEqual(endpoint1, client.Lookup("a"));

                    // Unregister endpoint
                    client.Unregister("a");

                    // NOTE(review): the pause presumably gives the server time to process an
                    // unregister request that is not acknowledged before Unregister returns;
                    // confirm before shortening or removing it.
                    Thread.Sleep(1000);

                    // Make sure it was unregistered correctly
                    Assert.IsNull(client.Lookup("a"));
                }
            }
        }

        /// <summary>Registering an existing id again replaces its endpoint.</summary>
        [TestMethod]
        public void TestLookup()
        {
            using (INameServer server = BuildNameServer())
            {
                using (INameClient client = BuildNameClient(server.LocalEndpoint))
                {
                    IPEndPoint endpoint1 = new IPEndPoint(IPAddress.Parse("100.0.0.1"), 100);
                    IPEndPoint endpoint2 = new IPEndPoint(IPAddress.Parse("100.0.0.2"), 200);

                    // Register endpoint1
                    client.Register("a", endpoint1);
                    Assert.AreEqual(endpoint1, client.Lookup("a"));

                    // Reregister identifier a
                    client.Register("a", endpoint2);
                    Assert.AreEqual(endpoint2, client.Lookup("a"));
                }
            }
        }

        /// <summary>
        /// A batch lookup returns assignments only for registered ids, in request order.
        /// </summary>
        [TestMethod]
        public void TestLookupList()
        {
            using (INameServer server = BuildNameServer())
            {
                using (INameClient client = BuildNameClient(server.LocalEndpoint))
                {
                    IPEndPoint endpoint1 = new IPEndPoint(IPAddress.Parse("100.0.0.1"), 100);
                    IPEndPoint endpoint2 = new IPEndPoint(IPAddress.Parse("100.0.0.2"), 200);
                    IPEndPoint endpoint3 = new IPEndPoint(IPAddress.Parse("100.0.0.3"), 300);

                    // Register endpoints
                    client.Register("a", endpoint1);
                    client.Register("b", endpoint2);
                    client.Register("c", endpoint3);

                    // Look up all of them at the same time
                    List<string> ids = new List<string> { "a", "b", "c", "d" };
                    List<NameAssignment> assignments = client.Lookup(ids);

                    // d is not registered, so only three assignments come back. Check the count
                    // first so a short list fails here rather than with an index exception.
                    Assert.AreEqual(3, assignments.Count);

                    // Check that a, b, and c are registered
                    Assert.AreEqual("a", assignments[0].Identifier);
                    Assert.AreEqual(endpoint1, assignments[0].Endpoint);
                    Assert.AreEqual("b", assignments[1].Identifier);
                    Assert.AreEqual(endpoint2, assignments[1].Endpoint);
                    Assert.AreEqual("c", assignments[2].Identifier);
                    Assert.AreEqual(endpoint3, assignments[2].Endpoint);
                }
            }
        }

        /// <summary>
        /// A client keeps working after it is restarted against a replacement name server.
        /// </summary>
        [TestMethod]
        public void TestNameClientRestart()
        {
            INameServer server = new NameServer(0);
            try
            {
                using (INameClient client = BuildNameClient(server.LocalEndpoint))
                {
                    IPEndPoint endpoint = new IPEndPoint(IPAddress.Parse("100.0.0.1"), 100);

                    client.Register("a", endpoint);
                    Assert.AreEqual(endpoint, client.Lookup("a"));

                    // Replace the server and point the existing client at the new one. Clear the
                    // reference first so the finally block never disposes the old server twice.
                    server.Dispose();
                    server = null;
                    server = new NameServer(0);
                    client.Restart(server.LocalEndpoint);

                    client.Register("b", endpoint);
                    Assert.AreEqual(endpoint, client.Lookup("b"));
                }
            }
            finally
            {
                // Release the listening port even when an assertion above fails.
                if (server != null)
                {
                    server.Dispose();
                }
            }
        }

        /// <summary>
        /// Tang can inject a <see cref="NameClient"/> built from <see cref="NamingConfiguration"/>
        /// into a constructor.
        /// </summary>
        [TestMethod]
        public void TestConstructorInjection()
        {
            using (INameServer server = new NameServer(0))
            {
                int port = server.LocalEndpoint.Port;
                IConfiguration nameClientConfiguration = NamingConfiguration.ConfigurationModule
                    .Set(NamingConfiguration.NameServerAddress, server.LocalEndpoint.Address.ToString())
                    .Set(NamingConfiguration.NameServerPort, port + string.Empty)
                    .Build();

                ConstructorInjection c = TangFactory.GetTang()
                    .NewInjector(nameClientConfiguration)
                    .GetInstance<ConstructorInjection>();

                Assert.IsNotNull(c);
            }
        }

        /// <summary>Builds a name server through Tang, bound to port 0.</summary>
        private INameServer BuildNameServer()
        {
            var builder = TangFactory.GetTang()
                .NewConfigurationBuilder()
                .BindNamedParameter<NamingConfigurationOptions.NameServerPort, int>(
                    GenericType<NamingConfigurationOptions.NameServerPort>.Class, "0");

            return TangFactory.GetTang().NewInjector(builder.Build()).GetInstance<INameServer>();
        }

        /// <summary>Builds a name client through Tang that talks to <paramref name="remoteEndpoint"/>.</summary>
        private INameClient BuildNameClient(IPEndPoint remoteEndpoint)
        {
            string nameServerAddr = remoteEndpoint.Address.ToString();
            int nameServerPort = remoteEndpoint.Port;
            IConfiguration nameClientConfiguration = NamingConfiguration.ConfigurationModule
                .Set(NamingConfiguration.NameServerAddress, nameServerAddr)
                .Set(NamingConfiguration.NameServerPort, nameServerPort + string.Empty)
                .Build();

            return TangFactory.GetTang().NewInjector(nameClientConfiguration).GetInstance<NameClient>();
        }

        /// <summary>Injection target whose constructor requires a non-null <see cref="NameClient"/>.</summary>
        private class ConstructorInjection
        {
            [Inject]
            public ConstructorInjection(NameClient client)
            {
                if (client == null)
                {
                    throw new ArgumentNullException("client");
                }
            }
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/987d4f37/lang/cs/Org.Apache.REEF.Network.Tests/NetworkService/NetworkServiceTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/NetworkService/NetworkServiceTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/NetworkService/NetworkServiceTests.cs new file mode 100644 index 0000000..1489b3c --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Network.Tests/NetworkService/NetworkServiceTests.cs @@ -0,0 +1,210 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Concurrent;
using System.Globalization;
using System.Linq;
using System.Net;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Org.Apache.REEF.Common.Io;
using Org.Apache.REEF.Network.Naming;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Tang.Implementations.Tang;
using Org.Apache.REEF.Tang.Util;
using Org.Apache.REEF.Wake;
using Org.Apache.REEF.Wake.Remote;
using Org.Apache.REEF.Wake.Remote.Impl;
using Org.Apache.REEF.Wake.Util;

namespace Org.Apache.REEF.Network.Tests.NetworkService
{
    /// <summary>
    /// Tests that string messages written on a connection opened by one
    /// network service are delivered to the handler of another network
    /// service, with both services registered through a shared name server.
    /// </summary>
    [TestClass]
    public class NetworkServiceTests
    {
        /// <summary>
        /// Upper bound on how long a test waits for any single message.
        /// BlockingCollection.Take() has no timeout, so a lost message would
        /// hang the whole test run instead of failing this test.
        /// </summary>
        private static readonly TimeSpan MessageTimeout = TimeSpan.FromSeconds(30);

        /// <summary>
        /// service1 writes three messages to service2; service2's handler must
        /// receive them in the order they were written.
        /// </summary>
        [TestMethod]
        public void TestNetworkServiceOneWayCommunication()
        {
            int networkServicePort1 = NetworkUtils.GenerateRandomPort(6000, 7000);
            int networkServicePort2 = NetworkUtils.GenerateRandomPort(7001, 8000);

            BlockingCollection<string> queue = new BlockingCollection<string>();

            // Port 0 is passed to the name server; the endpoint it actually
            // uses is read back from LocalEndpoint below.
            using (INameServer nameServer = new NameServer(0))
            {
                IPEndPoint endpoint = nameServer.LocalEndpoint;
                int nameServerPort = endpoint.Port;
                string nameServerAddr = endpoint.Address.ToString();

                // service1 only sends, so it is built through the Tang
                // injection path (null handler); service2 records what it receives.
                using (INetworkService<string> networkService1 = BuildNetworkService(networkServicePort1, nameServerPort, nameServerAddr, null))
                using (INetworkService<string> networkService2 = BuildNetworkService(networkServicePort2, nameServerPort, nameServerAddr, new MessageHandler(queue)))
                {
                    IIdentifier id1 = new StringIdentifier("service1");
                    IIdentifier id2 = new StringIdentifier("service2");
                    networkService1.Register(id1);
                    networkService2.Register(id2);

                    using (IConnection<string> connection = networkService1.NewConnection(id2))
                    {
                        connection.Open();
                        connection.Write("abc");
                        connection.Write("def");
                        connection.Write("ghi");

                        Assert.AreEqual("abc", TakeWithTimeout(queue));
                        Assert.AreEqual("def", TakeWithTimeout(queue));
                        Assert.AreEqual("ghi", TakeWithTimeout(queue));
                    }
                }
            }
        }

        /// <summary>
        /// Both services open a connection to each other. Each side's handler
        /// must receive exactly what the other side wrote, in order.
        /// </summary>
        [TestMethod]
        public void TestNetworkServiceTwoWayCommunication()
        {
            int networkServicePort1 = NetworkUtils.GenerateRandomPort(6000, 7000);
            int networkServicePort2 = NetworkUtils.GenerateRandomPort(7001, 8000);

            BlockingCollection<string> queue1 = new BlockingCollection<string>();
            BlockingCollection<string> queue2 = new BlockingCollection<string>();

            using (INameServer nameServer = new NameServer(0))
            {
                IPEndPoint endpoint = nameServer.LocalEndpoint;
                int nameServerPort = endpoint.Port;
                string nameServerAddr = endpoint.Address.ToString();
                using (INetworkService<string> networkService1 = BuildNetworkService(networkServicePort1, nameServerPort, nameServerAddr, new MessageHandler(queue1)))
                using (INetworkService<string> networkService2 = BuildNetworkService(networkServicePort2, nameServerPort, nameServerAddr, new MessageHandler(queue2)))
                {
                    IIdentifier id1 = new StringIdentifier("service1");
                    IIdentifier id2 = new StringIdentifier("service2");
                    networkService1.Register(id1);
                    networkService2.Register(id2);

                    using (IConnection<string> connection1 = networkService1.NewConnection(id2))
                    using (IConnection<string> connection2 = networkService2.NewConnection(id1))
                    {
                        connection1.Open();
                        connection1.Write("abc");
                        connection1.Write("def");
                        connection1.Write("ghi");

                        connection2.Open();
                        connection2.Write("jkl");
                        connection2.Write("mno");

                        // Messages sent by service1 land in service2's queue, and vice versa.
                        Assert.AreEqual("abc", TakeWithTimeout(queue2));
                        Assert.AreEqual("def", TakeWithTimeout(queue2));
                        Assert.AreEqual("ghi", TakeWithTimeout(queue2));

                        Assert.AreEqual("jkl", TakeWithTimeout(queue1));
                        Assert.AreEqual("mno", TakeWithTimeout(queue1));
                    }
                }
            }
        }

        /// <summary>
        /// Removes and returns the next message from <paramref name="queue"/>.
        /// Fails the test if no message arrives within <see cref="MessageTimeout"/>.
        /// </summary>
        /// <param name="queue">Queue filled by a <see cref="MessageHandler"/>.</param>
        /// <returns>The dequeued message.</returns>
        private static string TakeWithTimeout(BlockingCollection<string> queue)
        {
            string message;
            Assert.IsTrue(
                queue.TryTake(out message, MessageTimeout),
                "Timed out waiting for a network message.");
            return message;
        }

        /// <summary>
        /// Creates a string network service that uses the name server at
        /// <paramref name="nameServiceAddr"/>:<paramref name="nameServicePort"/>.
        /// </summary>
        /// <param name="networkServicePort">Port the network service is configured with.</param>
        /// <param name="nameServicePort">Port of the name server.</param>
        /// <param name="nameServiceAddr">Address of the name server.</param>
        /// <param name="handler">
        /// Receives incoming messages. When null, the service is built entirely
        /// through Tang injection with the no-op <see cref="NetworkMessageHandler"/>.
        /// </param>
        /// <returns>The constructed network service.</returns>
        private INetworkService<string> BuildNetworkService(
            int networkServicePort,
            int nameServicePort,
            string nameServiceAddr,
            IObserver<NsMessage<string>> handler)
        {
            // Ports are configuration values parsed back by Tang, so they are
            // formatted culture-independently.
            string networkServicePortValue = networkServicePort.ToString(CultureInfo.InvariantCulture);
            string nameServicePortValue = nameServicePort.ToString(CultureInfo.InvariantCulture);

            // Test injection: with no handler supplied, the whole service is
            // resolved by Tang.
            if (handler == null)
            {
                var networkServiceConf = TangFactory.GetTang().NewConfigurationBuilder()
                    .BindNamedParameter<NetworkServiceOptions.NetworkServicePort, int>(
                        GenericType<NetworkServiceOptions.NetworkServicePort>.Class,
                        networkServicePortValue)
                    .BindNamedParameter<NamingConfigurationOptions.NameServerPort, int>(
                        GenericType<NamingConfigurationOptions.NameServerPort>.Class,
                        nameServicePortValue)
                    .BindNamedParameter<NamingConfigurationOptions.NameServerAddress, string>(
                        GenericType<NamingConfigurationOptions.NameServerAddress>.Class,
                        nameServiceAddr)
                    .BindImplementation(GenericType<INameClient>.Class, GenericType<NameClient>.Class)
                    .BindImplementation(GenericType<ICodec<string>>.Class, GenericType<StringCodec>.Class)
                    .BindImplementation(GenericType<IObserver<NsMessage<string>>>.Class, GenericType<NetworkMessageHandler>.Class)
                    .Build();

                return TangFactory.GetTang().NewInjector(networkServiceConf).GetInstance<NetworkService<string>>();
            }

            // Otherwise only the name client comes from Tang; the service itself
            // is constructed directly so the caller-supplied handler is used.
            var nameserverConf = TangFactory.GetTang().NewConfigurationBuilder()
                .BindNamedParameter<NamingConfigurationOptions.NameServerPort, int>(
                    GenericType<NamingConfigurationOptions.NameServerPort>.Class,
                    nameServicePortValue)
                .BindNamedParameter<NamingConfigurationOptions.NameServerAddress, string>(
                    GenericType<NamingConfigurationOptions.NameServerAddress>.Class,
                    nameServiceAddr)
                .BindImplementation(GenericType<INameClient>.Class, GenericType<NameClient>.Class)
                .Build();

            var nameClient = TangFactory.GetTang().NewInjector(nameserverConf).GetInstance<NameClient>();
            return new NetworkService<string>(networkServicePort,
                handler, new StringIdentifierFactory(), new StringCodec(), nameClient);
        }

        /// <summary>
        /// Observer that adds the first data element of every received message
        /// to a queue so the test thread can inspect it.
        /// </summary>
        private class MessageHandler : IObserver<NsMessage<string>>
        {
            private readonly BlockingCollection<string> _queue;

            public MessageHandler(BlockingCollection<string> queue)
            {
                _queue = queue;
            }

            public void OnNext(NsMessage<string> value)
            {
                _queue.Add(value.Data.First());
            }

            public void OnError(Exception error)
            {
                throw new NotImplementedException();
            }

            public void OnCompleted()
            {
                throw new NotImplementedException();
            }
        }

        /// <summary>
        /// No-op message observer with an [Inject] constructor, bound by Tang
        /// when a network service is built without an explicit handler.
        /// </summary>
        private class NetworkMessageHandler : IObserver<NsMessage<string>>
        {
            [Inject]
            public NetworkMessageHandler()
            {
            }

            public void OnNext(NsMessage<string> value)
            {
            }

            public void OnError(Exception error)
            {
            }

            public void OnCompleted()
            {
            }
        }
    }
}
