Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/111#discussion_r136904213
  
    --- Diff: core/src/main/java/hivemall/recommend/SlimUDTF.java ---
    @@ -0,0 +1,625 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.recommend;
    +
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.annotations.VisibleForTesting;
    +import hivemall.common.ConversionState;
    +import hivemall.math.matrix.sparse.DoKMatrix;
    +import hivemall.math.vector.VectorProcedure;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.io.FileUtils;
    +import hivemall.utils.io.NioStatefullSegment;
    +import hivemall.utils.lang.NumberUtils;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.SizeOf;
    +import hivemall.utils.lang.mutable.MutableDouble;
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.*;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.*;
    +import org.apache.hadoop.io.DoubleWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.mapred.Counters;
    +import org.apache.hadoop.mapred.Reporter;
    +
    +import javax.annotation.Nonnull;
    +import java.io.File;
    +import java.io.IOException;
    +import java.nio.ByteBuffer;
    +import java.util.*;
    +
    +
    +public class SlimUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(SlimUDTF.class);
    +
    +    private double l1;
    +    private double l2;
    +    private int numIterations;
    +    private int previousItemId;
    +
    +    private transient DoKMatrix weightMatrix; // item-item weight matrix
    +    private transient DoKMatrix dataMatrix; // item-user matrix
    +
    +    private PrimitiveObjectInspector itemIOI;
    +    private PrimitiveObjectInspector itemJOI;
    +    private MapObjectInspector riOI;
    +    private MapObjectInspector rjOI;
    +
    +    private MapObjectInspector knnItemsOI;
    +    private PrimitiveObjectInspector knnItemsKeyOI;
    +    private MapObjectInspector knnItemsValueOI;
    +    private PrimitiveObjectInspector knnItemsValueKeyOI;
    +    private PrimitiveObjectInspector knnItemsValueValueOI;
    +
    +    private PrimitiveObjectInspector riKeyOI;
    +    private PrimitiveObjectInspector riValueOI;
    +
    +    private PrimitiveObjectInspector rjKeyOI;
    +    private PrimitiveObjectInspector rjValueOI;
    +
    +    // used to store KNN data into temporary file for iterative training
    +    private NioStatefullSegment fileIO;
    +    private ByteBuffer inputBuf;
    +
    +    private ConversionState cvState;
    +    private long observedTrainingExamples;
    +
    +    public SlimUDTF() {}
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) 
throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +        if (numArgs != 5 && numArgs != 6) {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes arguments: int i, map<int, double> r_i, 
map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j, [, 
constant string options]");
    +        }
    +
    +        this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
    +
    +        this.riOI = HiveUtils.asMapOI(argOIs[1]);
    +        this.riKeyOI = 
HiveUtils.asIntCompatibleOI((this.riOI.getMapKeyObjectInspector()));
    +        this.riValueOI = 
HiveUtils.asPrimitiveObjectInspector((this.riOI.getMapValueObjectInspector()));
    +
    +        this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
    +        this.knnItemsKeyOI = 
HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
    +        this.knnItemsValueOI = 
HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
    +        this.knnItemsValueKeyOI = 
HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
    +        this.knnItemsValueValueOI = 
HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());
    +
    +        this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
    +
    +        this.rjOI = HiveUtils.asMapOI(argOIs[4]);
    +        this.rjKeyOI = 
HiveUtils.asIntCompatibleOI((this.rjOI.getMapKeyObjectInspector()));
    +        this.rjValueOI = 
HiveUtils.asPrimitiveObjectInspector((this.rjOI.getMapValueObjectInspector()));
    +
    +        processOptions(argOIs);
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("i");
    +        fieldNames.add("j");
    +        fieldNames.add("wij");
    +
    +        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    +
    +        this.observedTrainingExamples = 0L;
    +        this.previousItemId = -2147483648;
    +
    +        this.dataMatrix = null;
    +        this.weightMatrix = null;
    +
    +        return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("l1", "l1coefficient", true,
    +            "Coefficient for l1 regularizer [default: 0.01]");
    +        opts.addOption("l2", "l2coefficient", true,
    +            "Coefficient for l2 regularizer [default: 0.01]");
    +        opts.addOption("numIterations", "iteration", true,
    +            "The number of iterations for coordinate descent [default: 
40]");
    +        opts.addOption("disable_cv", "disable_cvtest", false,
    +            "Whether to disable convergence check [default: enabled]");
    +        opts.addOption("cv_rate", "convergence_rate", true,
    +            "Threshold to determine convergence [default: 0.005]");
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
    +        CommandLine cl = null;
    +        double l1 = 0.01d;
    +        double l2 = 0.01d;
    +        int numIterations = 40;
    +        boolean conversionCheck = true;
    +        double cv_rate = 0.005d;
    +
    +        if (argOIs.length >= 6) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[5]);
    +            cl = parseOptions(rawArgs);
    +
    +            l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
    +            if (l1 < 0.d || l1 > 1.d) {
    --- End diff --
    
    `if (l1 < 0.d)` is enough; the `l1 > 1.d` upper-bound check is unnecessary.


---

Reply via email to