zhipeng93 commented on code in PR #155:
URL: https://github.com/apache/flink-ml/pull/155#discussion_r975019713


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansionParams.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.polynomialexpansion;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link PolynomialExpansion}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface PolynomialExpansionParams<T> extends HasInputCol<T>, 
HasOutputCol<T> {
+    Param<Integer> DEGREE =
+            new IntParam(
+                    "degree", "Degree of the polynomial expansion.", 2, 
ParamValidators.gtEq(1));
+
+    default Integer getDegree() {

Review Comment:
   How about we return `int` instead of `Integer`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java:
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.polynomialexpansion;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.util.ArithmeticUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that expands the input vectors in polynomial space.
+ *
+ * <p>Take a 2-dimension vector as an example: `(x, y)`, if we want to expand 
it with degree 2, then
+ * we get `(x, x * x, y, x * y, y * y)`.
+ *
+ * <p>For more information about the polynomial expansion, see
+ * http://en.wikipedia.org/wiki/Polynomial_expansion.
+ */
+public class PolynomialExpansion
+        implements Transformer<PolynomialExpansion>,
+                PolynomialExpansionParams<PolynomialExpansion> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public PolynomialExpansion() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new PolynomialExpansionFunction(getDegree(), 
getInputCol()),
+                                outputTypeInfo);
+
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static PolynomialExpansion load(StreamTableEnvironment env, String 
path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Polynomial expansion function that expands a vector in polynomial 
space. This expansion is
+     * done using recursion. Given input vector and degree, the size after 
expansion is (vectorSize
+     * + degree) (including 1 and first-order values). For example, let f([a, 
b, c], 3) be the
+     * function that expands [a, b, c] to their monomials of degree 3. We have 
the following
+     * recursion:
+     *
+     * <blockquote>
+     *
+     * $$ f([a, b, c], 3) &= f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) 
* c^2 ++ [c^3] $$
+     *
+     * </blockquote>
+     *
+     * <p>To handle sparsity, if c is zero, we can skip all monomials that 
contain it. We remember
+     * the current index and increment it properly for sparse input.
+     */
+    private static class PolynomialExpansionFunction implements 
MapFunction<Row, Row> {
+        private final int degree;
+        private final String inputCol;
+
+        public PolynomialExpansionFunction(int degree, String inputCol) {
+            this.degree = degree;
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            Vector vec = row.getFieldAs(inputCol);
+            if (vec == null) {

Review Comment:
   We should probably throw an exception when `vec` is null. Currently, the code cannot pass the unit test when `vec` is null.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java:
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.polynomialexpansion;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.util.ArithmeticUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that expands the input vectors in polynomial space.
+ *
+ * <p>Take a 2-dimension vector as an example: `(x, y)`, if we want to expand 
it with degree 2, then
+ * we get `(x, x * x, y, x * y, y * y)`.
+ *
+ * <p>For more information about the polynomial expansion, see
+ * http://en.wikipedia.org/wiki/Polynomial_expansion.
+ */
+public class PolynomialExpansion
+        implements Transformer<PolynomialExpansion>,
+                PolynomialExpansionParams<PolynomialExpansion> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public PolynomialExpansion() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new PolynomialExpansionFunction(getDegree(), 
getInputCol()),
+                                outputTypeInfo);
+
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static PolynomialExpansion load(StreamTableEnvironment env, String 
path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Polynomial expansion function that expands a vector in polynomial 
space. This expansion is
+     * done using recursion. Given input vector and degree, the size after 
expansion is (vectorSize
+     * + degree) (including 1 and first-order values). For example, let f([a, 
b, c], 3) be the
+     * function that expands [a, b, c] to their monomials of degree 3. We have 
the following
+     * recursion:
+     *
+     * <blockquote>
+     *
+     * $$ f([a, b, c], 3) &= f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) 
* c^2 ++ [c^3] $$
+     *
+     * </blockquote>
+     *
+     * <p>To handle sparsity, if c is zero, we can skip all monomials that 
contain it. We remember
+     * the current index and increment it properly for sparse input.
+     */
+    private static class PolynomialExpansionFunction implements 
MapFunction<Row, Row> {
+        private final int degree;
+        private final String inputCol;
+
+        public PolynomialExpansionFunction(int degree, String inputCol) {
+            this.degree = degree;
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            Vector vec = row.getFieldAs(inputCol);
+            if (vec == null) {
+                return row;
+            }
+            Vector outputVec;
+            if (vec instanceof DenseVector) {
+                int size = vec.size();
+                double[] retVals = new double[getResultVectorSize(size, 
degree) - 1];
+                expandDenseVector(((DenseVector) vec).values, size - 1, 
degree, 1.0, retVals, -1);
+                outputVec = new DenseVector(retVals);
+            } else if (vec instanceof SparseVector) {
+                SparseVector sparseVec = (SparseVector) vec;
+                int[] indices = sparseVec.indices;
+                double[] values = sparseVec.values;
+                int size = sparseVec.size();
+                int nnz = sparseVec.values.length;
+                int nnzPolySize = getResultVectorSize(nnz, degree);
+
+                Tuple2<Integer, int[]> polyIndices = Tuple2.of(0, new 
int[nnzPolySize - 1]);
+                Tuple2<Integer, double[]> polyValues = Tuple2.of(0, new 
double[nnzPolySize - 1]);
+                expandSparseVector(
+                        indices,
+                        values,
+                        nnz - 1,
+                        size - 1,
+                        degree,
+                        1.0,
+                        polyIndices,
+                        polyValues,
+                        -1);
+
+                outputVec =
+                        new SparseVector(
+                                getResultVectorSize(size, degree) - 1,
+                                polyIndices.f1,
+                                polyValues.f1);
+            } else {
+                throw new UnsupportedOperationException(
+                        "Only supports DenseVector or SparseVector.");
+            }
+            return Row.join(row, Row.of(outputVec));
+        }
+
+        /** Calculates the length of the expended vector. */
+        private static int getResultVectorSize(int num, int degree) {
+            if (num == 0) {
+                return 1;
+            }
+
+            if (num == 1 || degree == 1) {
+                return num + degree;
+            }
+
+            if (degree > num) {
+                return getResultVectorSize(degree, num);
+            }
+
+            long res = 1;
+            int i = num + 1;
+            int j;
+
+            if (num + degree < 61) {
+                for (j = 1; j <= degree; ++j) {
+                    res = res * i / j;
+                    ++i;
+                }
+            } else {
+                int depth;
+                for (j = 1; j <= degree; ++j) {
+                    depth = ArithmeticUtils.gcd(i, j);
+                    res = ArithmeticUtils.mulAndCheck(res / (j / depth), i / 
depth);
+                    ++i;
+                }
+            }
+
+            if (res > Integer.MAX_VALUE) {
+                throw new RuntimeException("The expended polynomial size is 
too large.");
+            }
+            return (int) res;
+        }
+
+        /** Expands the dense vector in polynomial space. */
+        private int expandDenseVector(
+                double[] values,
+                int lastIdx,
+                int degree,
+                double factor,
+                double[] retValues,
+                int curPolyIdx) {
+            if (!Double.valueOf(factor).equals(0.0)) {
+                if (degree == 0 || lastIdx < 0) {
+                    if (curPolyIdx >= 0) {
+                        retValues[curPolyIdx] = factor;
+                    }
+                } else {
+                    double v = values[lastIdx];
+                    int newLastIdx = lastIdx - 1;
+                    double alpha = factor;
+                    int i = 0;
+                    int curStart = curPolyIdx;
+
+                    while (i <= degree && Math.abs(alpha) > 0.0) {
+                        curStart =
+                                expandDenseVector(
+                                        values, newLastIdx, degree - i, alpha, 
retValues, curStart);
+                        i += 1;
+                        alpha *= v;
+                    }
+                }
+            }
+            return curPolyIdx + getResultVectorSize(lastIdx + 1, degree);
+        }
+
+        /** Expands the sparse vector in polynomial space. */
+        private int expandSparseVector(

Review Comment:
   nit: this could be static.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java:
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link PolynomialExpansion}. */
+public class PolynomialExpansionTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.dense(1.0, 2.0, 3.0),
+                            Vectors.sparse(5, new int[] {1, 4}, new double[] 
{2.0, 3.0})),
+                    Row.of(
+                            Vectors.dense(2.0, 3.0),
+                            Vectors.sparse(5, new int[] {1, 4}, new double[] 
{2.0, 1.0})));
+
+    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+            Arrays.asList(
+                    Vectors.dense(1.0, 1.0, 2.0, 2.0, 4.0, 3.0, 3.0, 6.0, 9.0),
+                    Vectors.dense(2.0, 4.0, 3.0, 6.0, 9.0));
+
+    private static final List<Vector> EXPECTED_DENSE_OUTPUT_WITH_DEGREE_3 =
+            Arrays.asList(
+                    Vectors.dense(
+                            1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 4.0, 4.0, 8.0, 3.0, 
3.0, 3.0, 6.0, 6.0,
+                            12.0, 9.0, 9.0, 18.0, 27.0),
+                    Vectors.dense(2.0, 4.0, 8.0, 3.0, 6.0, 12.0, 9.0, 18.0, 
27.0));
+
+    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+            Arrays.asList(
+                    Vectors.sparse(
+                            55,
+                            new int[] {3, 6, 8, 34, 37, 39, 49, 51, 54},
+                            new double[] {2.0, 4.0, 8.0, 3.0, 6.0, 12.0, 9.0, 
18.0, 27.0}),
+                    Vectors.sparse(
+                            55,
+                            new int[] {3, 6, 8, 34, 37, 39, 49, 51, 54},
+                            new double[] {2.0, 4.0, 8.0, 1.0, 2.0, 4.0, 1.0, 
2.0, 1.0}));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment(config);
+
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("denseVec", 
"sparseVec");
+    }
+
+    private void verifyOutputResult(Table output, String outputCol, 
List<Vector> expectedData)
+            throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+
+        List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
+        List<Vector> resultVec = new ArrayList<>(results.size());
+        for (Row row : results) {
+            if (row.getField(outputCol) != null) {
+                resultVec.add(row.getFieldAs(outputCol));
+            }
+        }
+        compareResultCollections(expectedData, resultVec, TestUtils::compare);
+    }
+
+    @Test
+    public void testParam() {
+        PolynomialExpansion polynomialExpansion = new PolynomialExpansion();
+        assertEquals("input", polynomialExpansion.getInputCol());
+        assertEquals("output", polynomialExpansion.getOutputCol());
+        assertEquals(Integer.valueOf(2), polynomialExpansion.getDegree());

Review Comment:
   nit: The expected value here could simply be `2` if `getDegree()` returned `int` instead of `Integer`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java:
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.polynomialexpansion;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.util.ArithmeticUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that expands the input vectors in polynomial space.
+ *
+ * <p>Take a 2-dimension vector as an example: `(x, y)`, if we want to expand 
it with degree 2, then
+ * we get `(x, x * x, y, x * y, y * y)`.
+ *
+ * <p>For more information about the polynomial expansion, see
+ * http://en.wikipedia.org/wiki/Polynomial_expansion.
+ */
+public class PolynomialExpansion
+        implements Transformer<PolynomialExpansion>,
+                PolynomialExpansionParams<PolynomialExpansion> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public PolynomialExpansion() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new PolynomialExpansionFunction(getDegree(), 
getInputCol()),
+                                outputTypeInfo);
+
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static PolynomialExpansion load(StreamTableEnvironment env, String 
path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Polynomial expansion function that expands a vector in polynomial 
space. This expansion is
+     * done using recursion. Given input vector and degree, the size after 
expansion is (vectorSize
+     * + degree) (including 1 and first-order values). For example, let f([a, 
b, c], 3) be the
+     * function that expands [a, b, c] to their monomials of degree 3. We have 
the following
+     * recursion:
+     *
+     * <blockquote>
+     *
+     * $$ f([a, b, c], 3) &= f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) 
* c^2 ++ [c^3] $$
+     *
+     * </blockquote>
+     *
+     * <p>To handle sparsity, if c is zero, we can skip all monomials that 
contain it. We remember
+     * the current index and increment it properly for sparse input.
+     */
+    private static class PolynomialExpansionFunction implements 
MapFunction<Row, Row> {
+        private final int degree;
+        private final String inputCol;
+
+        public PolynomialExpansionFunction(int degree, String inputCol) {
+            this.degree = degree;
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            Vector vec = row.getFieldAs(inputCol);
+            if (vec == null) {
+                return row;
+            }
+            Vector outputVec;
+            if (vec instanceof DenseVector) {
+                int size = vec.size();
+                double[] retVals = new double[getResultVectorSize(size, 
degree) - 1];
+                expandDenseVector(((DenseVector) vec).values, size - 1, 
degree, 1.0, retVals, -1);
+                outputVec = new DenseVector(retVals);
+            } else if (vec instanceof SparseVector) {
+                SparseVector sparseVec = (SparseVector) vec;
+                int[] indices = sparseVec.indices;
+                double[] values = sparseVec.values;
+                int size = sparseVec.size();
+                int nnz = sparseVec.values.length;
+                int nnzPolySize = getResultVectorSize(nnz, degree);
+
+                Tuple2<Integer, int[]> polyIndices = Tuple2.of(0, new 
int[nnzPolySize - 1]);
+                Tuple2<Integer, double[]> polyValues = Tuple2.of(0, new 
double[nnzPolySize - 1]);
+                expandSparseVector(
+                        indices,
+                        values,
+                        nnz - 1,
+                        size - 1,
+                        degree,
+                        1.0,
+                        polyIndices,
+                        polyValues,
+                        -1);
+
+                outputVec =
+                        new SparseVector(
+                                getResultVectorSize(size, degree) - 1,
+                                polyIndices.f1,
+                                polyValues.f1);
+            } else {
+                throw new UnsupportedOperationException(
+                        "Only supports DenseVector or SparseVector.");
+            }
+            return Row.join(row, Row.of(outputVec));
+        }
+
+        /** Calculates the length of the expended vector. */
+        private static int getResultVectorSize(int num, int degree) {
+            if (num == 0) {
+                return 1;
+            }
+
+            if (num == 1 || degree == 1) {
+                return num + degree;
+            }
+
+            if (degree > num) {
+                return getResultVectorSize(degree, num);
+            }
+
+            long res = 1;
+            int i = num + 1;
+            int j;
+
+            if (num + degree < 61) {
+                for (j = 1; j <= degree; ++j) {
+                    res = res * i / j;
+                    ++i;
+                }
+            } else {
+                int depth;
+                for (j = 1; j <= degree; ++j) {
+                    depth = ArithmeticUtils.gcd(i, j);
+                    res = ArithmeticUtils.mulAndCheck(res / (j / depth), i / 
depth);
+                    ++i;
+                }
+            }
+
+            if (res > Integer.MAX_VALUE) {
+                throw new RuntimeException("The expended polynomial size is 
too large.");
+            }
+            return (int) res;
+        }
+
+        /** Expands the dense vector in polynomial space. */
+        private int expandDenseVector(

Review Comment:
   nit: this method could be static.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to