Move code from Enumerables to EnumerableDefaults
Project: http://git-wip-us.apache.org/repos/asf/calcite/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/b0a996b4 Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/b0a996b4 Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/b0a996b4 Branch: refs/heads/master Commit: b0a996b4105c238cd8455741b2fb4b0ba5256a49 Parents: b6f0e10 Author: Julian Hyde <[email protected]> Authored: Tue May 17 18:14:11 2016 -0700 Committer: Julian Hyde <[email protected]> Committed: Sun May 22 12:46:46 2016 -0700 ---------------------------------------------------------------------- .../org/apache/calcite/runtime/Enumerables.java | 285 +----------------- .../org/apache/calcite/util/BuiltInMethod.java | 15 +- .../apache/calcite/runtime/EnumerablesTest.java | 29 +- .../calcite/linq4j/EnumerableDefaults.java | 290 +++++++++++++++++-- 4 files changed, 296 insertions(+), 323 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite/blob/b0a996b4/core/src/main/java/org/apache/calcite/runtime/Enumerables.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/runtime/Enumerables.java b/core/src/main/java/org/apache/calcite/runtime/Enumerables.java index 22a85c7..0eba4f1 100644 --- a/core/src/main/java/org/apache/calcite/runtime/Enumerables.java +++ b/core/src/main/java/org/apache/calcite/runtime/Enumerables.java @@ -17,31 +17,17 @@ package org.apache.calcite.runtime; import org.apache.calcite.interpreter.Row; -import org.apache.calcite.linq4j.AbstractEnumerable; import org.apache.calcite.linq4j.Enumerable; -import org.apache.calcite.linq4j.Enumerator; -import org.apache.calcite.linq4j.Linq4j; -import org.apache.calcite.linq4j.function.EqualityComparer; import org.apache.calcite.linq4j.function.Function1; -import org.apache.calcite.linq4j.function.Function2; -import org.apache.calcite.linq4j.function.Predicate1; -import org.apache.calcite.linq4j.function.Predicate2; -import org.apache.calcite.util.Bug; import com.google.common.base.Supplier; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; - -import java.util.List; -import java.util.Set; /** * Utilities for processing {@link org.apache.calcite.linq4j.Enumerable} * collections. * * <p>This class is a place to put things not yet added to linq4j. - * Methods are subject to removal without notice.</p> + * Methods are subject to removal without notice. */ public class Enumerables { private static final Function1<?, ?> SLICE = @@ -67,137 +53,6 @@ public class Enumerables { return enumerable.select((Function1<E[], E>) SLICE); } - /** - * Returns elements of {@code outer} for which there is a member of - * {@code inner} with a matching key. - */ - public static <TSource, TInner, TKey> Enumerable<TSource> semiJoin( - final Enumerable<TSource> outer, final Enumerable<TInner> inner, - final Function1<TSource, TKey> outerKeySelector, - final Function1<TInner, TKey> innerKeySelector) { - Bug.upgrade("move into linq4j"); - return semiJoin(outer, inner, outerKeySelector, innerKeySelector, null); - } - - /** - * Returns elements of {@code outer} for which there is a member of - * {@code inner} with a matching key. A specified - * {@code EqualityComparer<TSource>} is used to compare keys. - */ - public static <TSource, TInner, TKey> Enumerable<TSource> semiJoin( - final Enumerable<TSource> outer, final Enumerable<TInner> inner, - final Function1<TSource, TKey> outerKeySelector, - final Function1<TInner, TKey> innerKeySelector, - final EqualityComparer<TKey> comparer) { - return new AbstractEnumerable<TSource>() { - public Enumerator<TSource> enumerator() { - final Enumerable<TKey> innerLookup = - comparer == null - ? inner.select(innerKeySelector).distinct() - : inner.select(innerKeySelector).distinct(comparer); - - return Enumerables.where(outer.enumerator(), - new Predicate1<TSource>() { - public boolean apply(TSource v0) { - final TKey key = outerKeySelector.apply(v0); - return innerLookup.contains(key); - } - }); - } - }; - } - - /** - * Correlates the elements of two sequences based on a predicate. - */ - public static <TSource, TInner, TResult> Enumerable<TResult> thetaJoin( - final Enumerable<TSource> outer, final Enumerable<TInner> inner, - final Predicate2<TSource, TInner> predicate, - Function2<TSource, TInner, TResult> resultSelector, - final boolean generateNullsOnLeft, - final boolean generateNullsOnRight) { - // Building the result as a list is easy but hogs memory. We should iterate. - final List<TResult> result = Lists.newArrayList(); - final Enumerator<TSource> lefts = outer.enumerator(); - final List<TInner> rightList = inner.toList(); - final Set<TInner> rightUnmatched; - if (generateNullsOnLeft) { - rightUnmatched = Sets.newIdentityHashSet(); - rightUnmatched.addAll(rightList); - } else { - rightUnmatched = null; - } - while (lefts.moveNext()) { - int leftMatchCount = 0; - final TSource left = lefts.current(); - final Enumerator<TInner> rights = Linq4j.iterableEnumerator(rightList); - while (rights.moveNext()) { - TInner right = rights.current(); - if (predicate.apply(left, right)) { - ++leftMatchCount; - if (rightUnmatched != null) { - rightUnmatched.remove(right); - } - result.add(resultSelector.apply(left, right)); - } - } - if (generateNullsOnRight && leftMatchCount == 0) { - result.add(resultSelector.apply(left, null)); - } - } - if (rightUnmatched != null) { - final Enumerator<TInner> rights = - Linq4j.iterableEnumerator(rightUnmatched); - while (rights.moveNext()) { - TInner right = rights.current(); - result.add(resultSelector.apply(null, right)); - } - } - return Linq4j.asEnumerable(result); - } - - /** - * Filters a sequence of values based on a - * predicate. - */ - public static <TSource> Enumerable<TSource> where( - final Enumerable<TSource> source, final Predicate1<TSource> predicate) { - assert predicate != null; - return new AbstractEnumerable<TSource>() { - public Enumerator<TSource> enumerator() { - final Enumerator<TSource> enumerator = source.enumerator(); - return Enumerables.where(enumerator, predicate); - } - }; - } - - private static <TSource> Enumerator<TSource> where( - final Enumerator<TSource> enumerator, - final Predicate1<TSource> predicate) { - return new Enumerator<TSource>() { - public TSource current() { - return enumerator.current(); - } - - public boolean moveNext() { - while (enumerator.moveNext()) { - if (predicate.apply(enumerator.current())) { - return true; - } - } - return false; - } - - public void reset() { - enumerator.reset(); - } - - public void close() { - enumerator.close(); - } - }; - } - /** Converts an {@link Enumerable} over object arrays into an * {@link Enumerable} over {@link Row} objects. */ public static Enumerable<Row> toRow(final Enumerable<Object[]> enumerable) { @@ -215,144 +70,6 @@ public class Enumerables { }; } - /** Joins two inputs that are sorted on the key. */ - public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> - Enumerable<TResult> mergeJoin(final Enumerable<TSource> outer, - final Enumerable<TInner> inner, - final Function1<TSource, TKey> outerKeySelector, - final Function1<TInner, TKey> innerKeySelector, - final Function2<TSource, TInner, TResult> resultSelector, - boolean generateNullsOnLeft, - boolean generateNullsOnRight) { - assert !generateNullsOnLeft : "not implemented"; - assert !generateNullsOnRight : "not implemented"; - return new AbstractEnumerable<TResult>() { - public Enumerator<TResult> enumerator() { - return new Enumerator<TResult>() { - final Enumerator<TSource> leftEnumerator = outer.enumerator(); - final Enumerator<TInner> rightEnumerator = inner.enumerator(); - final List<TSource> lefts = Lists.newArrayList(); - final List<TInner> rights = Lists.newArrayList(); - boolean done; - Enumerator<List<Object>> cartesians; - - { - start(); - } - - private void start() { - if (!leftEnumerator.moveNext() - || !rightEnumerator.moveNext() - || !advance()) { - done = true; - cartesians = Linq4j.emptyEnumerator(); - } - } - - /** Moves to the next key that is present in both sides. Populates - * lefts and rights with the rows. Restarts the cross-join - * enumerator. */ - private boolean advance() { - TSource left = leftEnumerator.current(); - TKey leftKey = outerKeySelector.apply(left); - TInner right = rightEnumerator.current(); - TKey rightKey = innerKeySelector.apply(right); - for (;;) { - int c = leftKey.compareTo(rightKey); - if (c == 0) { - break; - } - if (c < 0) { - if (!leftEnumerator.moveNext()) { - done = true; - return false; - } - left = leftEnumerator.current(); - leftKey = outerKeySelector.apply(left); - } else { - if (!rightEnumerator.moveNext()) { - done = true; - return false; - } - right = rightEnumerator.current(); - rightKey = innerKeySelector.apply(right); - } - } - lefts.clear(); - lefts.add(left); - for (;;) { - if (!leftEnumerator.moveNext()) { - done = true; - break; - } - left = leftEnumerator.current(); - TKey leftKey2 = outerKeySelector.apply(left); - int c = leftKey.compareTo(leftKey2); - if (c != 0) { - assert c < 0 : "not sorted"; - break; - } - lefts.add(left); - } - rights.clear(); - rights.add(right); - for (;;) { - if (!rightEnumerator.moveNext()) { - done = true; - break; - } - right = rightEnumerator.current(); - TKey rightKey2 = innerKeySelector.apply(right); - int c = rightKey.compareTo(rightKey2); - if (c != 0) { - assert c < 0 : "not sorted"; - break; - } - rights.add(right); - } - cartesians = Linq4j.product( - ImmutableList.of(Linq4j.<Object>enumerator(lefts), - Linq4j.<Object>enumerator(rights))); - return true; - } - - public TResult current() { - final List<Object> list = cartesians.current(); - @SuppressWarnings("unchecked") final TSource left = - (TSource) list.get(0); - @SuppressWarnings("unchecked") final TInner right = - (TInner) list.get(1); - return resultSelector.apply(left, right); - } - - public boolean moveNext() { - for (;;) { - if (cartesians.moveNext()) { - return true; - } - if (done) { - return false; - } - if (!advance()) { - return false; - } - } - } - - public void reset() { - done = false; - leftEnumerator.reset(); - rightEnumerator.reset(); - start(); - } - - public void close() { - } - }; - } - }; - } - } // End Enumerables.java http://git-wip-us.apache.org/repos/asf/calcite/blob/b0a996b4/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 94356a4..53f7880 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -135,14 +135,15 @@ public enum BuiltInMethod { String.class, Function1.class), JOIN(ExtendedEnumerable.class, "join", Enumerable.class, Function1.class, Function1.class, Function2.class), - MERGE_JOIN(Enumerables.class, "mergeJoin", Enumerable.class, Enumerable.class, - Function1.class, Function1.class, Function2.class, boolean.class, - boolean.class), + MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class, + Enumerable.class, Function1.class, Function1.class, Function2.class, + boolean.class, boolean.class), SLICE0(Enumerables.class, "slice0", Enumerable.class), - SEMI_JOIN(Enumerables.class, "semiJoin", Enumerable.class, Enumerable.class, - Function1.class, Function1.class), - THETA_JOIN(Enumerables.class, "thetaJoin", Enumerable.class, Enumerable.class, - Predicate2.class, Function2.class, boolean.class, boolean.class), + SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class, + Enumerable.class, Function1.class, Function1.class), + THETA_JOIN(EnumerableDefaults.class, "thetaJoin", Enumerable.class, + Enumerable.class, Predicate2.class, Function2.class, boolean.class, + boolean.class), CORRELATE_JOIN(ExtendedEnumerable.class, "correlateJoin", CorrelateJoinType.class, Function1.class, Function2.class), SELECT(ExtendedEnumerable.class, "select", Function1.class), http://git-wip-us.apache.org/repos/asf/calcite/blob/b0a996b4/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java index 33a2593..5e76578 100644 --- a/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java +++ b/core/src/test/java/org/apache/calcite/runtime/EnumerablesTest.java @@ -17,6 +17,7 @@ package org.apache.calcite.runtime; import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.EnumerableDefaults; import org.apache.calcite.linq4j.Linq4j; import org.apache.calcite.linq4j.function.Function1; import org.apache.calcite.linq4j.function.Function2; @@ -69,7 +70,7 @@ public class EnumerablesTest { @Test public void testSemiJoin() { assertThat( - Enumerables.semiJoin(EMPS, DEPTS, + EnumerableDefaults.semiJoin(EMPS, DEPTS, new Function1<Emp, Integer>() { public Integer apply(Emp a0) { return a0.deptno; @@ -86,7 +87,7 @@ public class EnumerablesTest { @Test public void testMergeJoin() { assertThat( - Enumerables.mergeJoin( + EnumerableDefaults.mergeJoin( Linq4j.asEnumerable( Arrays.asList( new Emp(10, "Fred"), @@ -171,7 +172,7 @@ public class EnumerablesTest { private static <T extends Comparable<T>> Enumerable<T> intersect( List<T> list0, List<T> list1) { - return Enumerables.mergeJoin( + return EnumerableDefaults.mergeJoin( Linq4j.asEnumerable(list0), Linq4j.asEnumerable(list1), Functions.<T>identitySelector(), @@ -185,31 +186,31 @@ public class EnumerablesTest { @Test public void testThetaJoin() { assertThat( - Enumerables.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, EMP_DEPT_TO_STRING, - false, false).toList().toString(), + EnumerableDefaults.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, + EMP_DEPT_TO_STRING, false, false).toList().toString(), equalTo("[{Theodore, 20, 20, Sales}, {Sebastian, 20, 20, Sales}]")); } @Test public void testThetaLeftJoin() { assertThat( - Enumerables.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, EMP_DEPT_TO_STRING, - false, true).toList().toString(), + EnumerableDefaults.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, + EMP_DEPT_TO_STRING, false, true).toList().toString(), equalTo("[{Fred, 10, null, null}, {Theodore, 20, 20, Sales}, " + "{Sebastian, 20, 20, Sales}, {Joe, 30, null, null}]")); } @Test public void testThetaRightJoin() { assertThat( - Enumerables.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, EMP_DEPT_TO_STRING, - true, false).toList().toString(), + EnumerableDefaults.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, + EMP_DEPT_TO_STRING, true, false).toList().toString(), equalTo("[{Theodore, 20, 20, Sales}, {Sebastian, 20, 20, Sales}, " + "{null, null, 15, Marketing}]")); } @Test public void testThetaFullJoin() { assertThat( - Enumerables.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, EMP_DEPT_TO_STRING, - true, true).toList().toString(), + EnumerableDefaults.thetaJoin(EMPS, DEPTS, EQUAL_DEPTNO, + EMP_DEPT_TO_STRING, true, true).toList().toString(), equalTo("[{Fred, 10, null, null}, {Theodore, 20, 20, Sales}, " + "{Sebastian, 20, 20, Sales}, {Joe, 30, null, null}, " + "{null, null, 15, Marketing}]")); @@ -217,7 +218,7 @@ public class EnumerablesTest { @Test public void testThetaFullJoinLeftEmpty() { assertThat( - Enumerables.thetaJoin(EMPS.take(0), DEPTS, EQUAL_DEPTNO, + EnumerableDefaults.thetaJoin(EMPS.take(0), DEPTS, EQUAL_DEPTNO, EMP_DEPT_TO_STRING, true, true) .orderBy(Functions.<String>identitySelector()).toList().toString(), equalTo("[{null, null, 15, Marketing}, {null, null, 20, Sales}]")); @@ -225,7 +226,7 @@ public class EnumerablesTest { @Test public void testThetaFullJoinRightEmpty() { assertThat( - Enumerables.thetaJoin(EMPS, DEPTS.take(0), EQUAL_DEPTNO, + EnumerableDefaults.thetaJoin(EMPS, DEPTS.take(0), EQUAL_DEPTNO, EMP_DEPT_TO_STRING, true, true).toList().toString(), equalTo("[{Fred, 10, null, null}, {Theodore, 20, null, null}, " + "{Sebastian, 20, null, null}, {Joe, 30, null, null}]")); @@ -233,7 +234,7 @@ public class EnumerablesTest { @Test public void testThetaFullJoinBothEmpty() { assertThat( - Enumerables.thetaJoin(EMPS.take(0), DEPTS.take(0), EQUAL_DEPTNO, + EnumerableDefaults.thetaJoin(EMPS.take(0), DEPTS.take(0), EQUAL_DEPTNO, EMP_DEPT_TO_STRING, true, true).toList().toString(), equalTo("[]")); } http://git-wip-us.apache.org/repos/asf/calcite/blob/b0a996b4/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java ---------------------------------------------------------------------- diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 7024a86..dc3ef03 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -34,6 +34,10 @@ import org.apache.calcite.linq4j.function.NullableLongFunction1; import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.linq4j.function.Predicate2; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + import java.math.BigDecimal; import java.util.AbstractList; import java.util.AbstractMap; @@ -1292,6 +1296,114 @@ public abstract class EnumerableDefaults { } /** + * Returns elements of {@code outer} for which there is a member of + * {@code inner} with a matching key. + */ + public static <TSource, TInner, TKey> Enumerable<TSource> semiJoin( + final Enumerable<TSource> outer, final Enumerable<TInner> inner, + final Function1<TSource, TKey> outerKeySelector, + final Function1<TInner, TKey> innerKeySelector) { + return semiJoin(outer, inner, outerKeySelector, innerKeySelector, null); + } + + /** + * Returns elements of {@code outer} for which there is a member of + * {@code inner} with a matching key. A specified + * {@code EqualityComparer<TSource>} is used to compare keys. + */ + public static <TSource, TInner, TKey> Enumerable<TSource> semiJoin( + final Enumerable<TSource> outer, final Enumerable<TInner> inner, + final Function1<TSource, TKey> outerKeySelector, + final Function1<TInner, TKey> innerKeySelector, + final EqualityComparer<TKey> comparer) { + return new AbstractEnumerable<TSource>() { + public Enumerator<TSource> enumerator() { + final Enumerable<TKey> innerLookup = + comparer == null + ? inner.select(innerKeySelector).distinct() + : inner.select(innerKeySelector).distinct(comparer); + + return EnumerableDefaults.where(outer.enumerator(), + new Predicate1<TSource>() { + public boolean apply(TSource v0) { + final TKey key = outerKeySelector.apply(v0); + return innerLookup.contains(key); + } + }); + } + }; + } + + /** + * Correlates the elements of two sequences based on a predicate. + */ + public static <TSource, TInner, TResult> Enumerable<TResult> thetaJoin( + final Enumerable<TSource> outer, final Enumerable<TInner> inner, + final Predicate2<TSource, TInner> predicate, + Function2<TSource, TInner, TResult> resultSelector, + final boolean generateNullsOnLeft, + final boolean generateNullsOnRight) { + // Building the result as a list is easy but hogs memory. We should iterate. + final List<TResult> result = Lists.newArrayList(); + final Enumerator<TSource> lefts = outer.enumerator(); + final List<TInner> rightList = inner.toList(); + final Set<TInner> rightUnmatched; + if (generateNullsOnLeft) { + rightUnmatched = Sets.newIdentityHashSet(); + rightUnmatched.addAll(rightList); + } else { + rightUnmatched = null; + } + while (lefts.moveNext()) { + int leftMatchCount = 0; + final TSource left = lefts.current(); + final Enumerator<TInner> rights = Linq4j.iterableEnumerator(rightList); + while (rights.moveNext()) { + TInner right = rights.current(); + if (predicate.apply(left, right)) { + ++leftMatchCount; + if (rightUnmatched != null) { + rightUnmatched.remove(right); + } + result.add(resultSelector.apply(left, right)); + } + } + if (generateNullsOnRight && leftMatchCount == 0) { + result.add(resultSelector.apply(left, null)); + } + } + if (rightUnmatched != null) { + final Enumerator<TInner> rights = + Linq4j.iterableEnumerator(rightUnmatched); + while (rights.moveNext()) { + TInner right = rights.current(); + result.add(resultSelector.apply(null, right)); + } + } + return Linq4j.asEnumerable(result); + } + + /** Joins two inputs that are sorted on the key. */ + public static <TSource, TInner, TKey extends Comparable<TKey>, TResult> + Enumerable<TResult> mergeJoin(final Enumerable<TSource> outer, + final Enumerable<TInner> inner, + final Function1<TSource, TKey> outerKeySelector, + final Function1<TInner, TKey> innerKeySelector, + final Function2<TSource, TInner, TResult> resultSelector, + boolean generateNullsOnLeft, + boolean generateNullsOnRight) { + assert !generateNullsOnLeft : "not implemented"; + assert !generateNullsOnRight : "not implemented"; + return new AbstractEnumerable<TResult>() { + public Enumerator<TResult> enumerator() { + return new MergeJoinEnumerator<>(outer.enumerator(), + inner.enumerator(), outerKeySelector, innerKeySelector, + resultSelector); + } + }; + } + + /** * Returns the last element of a sequence that * satisfies a specified condition. */ @@ -2675,28 +2787,34 @@ public abstract class EnumerableDefaults { return new AbstractEnumerable<TSource>() { public Enumerator<TSource> enumerator() { final Enumerator<TSource> enumerator = source.enumerator(); - return new Enumerator<TSource>() { - public TSource current() { - return enumerator.current(); - } + return EnumerableDefaults.where(enumerator, predicate); + } + }; + } - public boolean moveNext() { - while (enumerator.moveNext()) { - if (predicate.apply(enumerator.current())) { - return true; - } - } - return false; - } + private static <TSource> Enumerator<TSource> where( + final Enumerator<TSource> enumerator, + final Predicate1<TSource> predicate) { + return new Enumerator<TSource>() { + public TSource current() { + return enumerator.current(); + } - public void reset() { - enumerator.reset(); + public boolean moveNext() { + while (enumerator.moveNext()) { + if (predicate.apply(enumerator.current())) { + return true; } + } + return false; + } - public void close() { - enumerator.close(); - } - }; + public void reset() { + enumerator.reset(); + } + + public void close() { + enumerator.close(); } }; } @@ -3075,6 +3193,142 @@ public abstract class EnumerableDefaults { }; } } + + /** Enumerator that performs a merge join on its sorted inputs. */ + private static class MergeJoinEnumerator<TResult, TSource, TInner, TKey extends Comparable<TKey>> + implements Enumerator<TResult> { + final List<TSource> lefts = new ArrayList<>(); + final List<TInner> rights = new ArrayList<>(); + private final Enumerator<TSource> leftEnumerator; + private final Enumerator<TInner> rightEnumerator; + private final Function1<TSource, TKey> outerKeySelector; + private final Function1<TInner, TKey> innerKeySelector; + private final Function2<TSource, TInner, TResult> resultSelector; + boolean done; + Enumerator<List<Object>> cartesians; + + MergeJoinEnumerator(Enumerator<TSource> leftEnumerator, + Enumerator<TInner> rightEnumerator, + Function1<TSource, TKey> outerKeySelector, + Function1<TInner, TKey> innerKeySelector, + Function2<TSource, TInner, TResult> resultSelector) { + this.leftEnumerator = leftEnumerator; + this.rightEnumerator = rightEnumerator; + this.outerKeySelector = outerKeySelector; + this.innerKeySelector = innerKeySelector; + this.resultSelector = resultSelector; + start(); + } + + private void start() { + if (!leftEnumerator.moveNext() + || !rightEnumerator.moveNext() + || !advance()) { + done = true; + cartesians = Linq4j.emptyEnumerator(); + } + } + + /** Moves to the next key that is present in both sides. Populates + * lefts and rights with the rows. Restarts the cross-join + * enumerator. */ + private boolean advance() { + TSource left = leftEnumerator.current(); + TKey leftKey = outerKeySelector.apply(left); + TInner right = rightEnumerator.current(); + TKey rightKey = innerKeySelector.apply(right); + for (;;) { + int c = leftKey.compareTo(rightKey); + if (c == 0) { + break; + } + if (c < 0) { + if (!leftEnumerator.moveNext()) { + done = true; + return false; + } + left = leftEnumerator.current(); + leftKey = outerKeySelector.apply(left); + } else { + if (!rightEnumerator.moveNext()) { + done = true; + return false; + } + right = rightEnumerator.current(); + rightKey = innerKeySelector.apply(right); + } + } + lefts.clear(); + lefts.add(left); + for (;;) { + if (!leftEnumerator.moveNext()) { + done = true; + break; + } + left = leftEnumerator.current(); + TKey leftKey2 = outerKeySelector.apply(left); + int c = leftKey.compareTo(leftKey2); + if (c != 0) { + assert c < 0 : "not sorted"; + break; + } + lefts.add(left); + } + rights.clear(); + rights.add(right); + for (;;) { + if (!rightEnumerator.moveNext()) { + done = true; + break; + } + right = rightEnumerator.current(); + TKey rightKey2 = innerKeySelector.apply(right); + int c = rightKey.compareTo(rightKey2); + if (c != 0) { + assert c < 0 : "not sorted"; + break; + } + rights.add(right); + } + cartesians = Linq4j.product( + ImmutableList.of(Linq4j.<Object>enumerator(lefts), + Linq4j.<Object>enumerator(rights))); + return true; + } + + public TResult current() { + final List<Object> list = cartesians.current(); + @SuppressWarnings("unchecked") final TSource left = + (TSource) list.get(0); + @SuppressWarnings("unchecked") final TInner right = + (TInner) list.get(1); + return resultSelector.apply(left, right); + } + + public boolean moveNext() { + for (;;) { + if (cartesians.moveNext()) { + return true; + } + if (done) { + return false; + } + if (!advance()) { + return false; + } + } + } + + public void reset() { + done = false; + leftEnumerator.reset(); + rightEnumerator.reset(); + start(); + } + + public void close() { + } + } } // End EnumerableDefaults.java
