Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/111#discussion_r136904213 --- Diff: core/src/main/java/hivemall/recommend/SlimUDTF.java --- @@ -0,0 +1,625 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.recommend; + + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; +import hivemall.math.matrix.sparse.DoKMatrix; +import hivemall.math.vector.VectorProcedure; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; +import hivemall.utils.lang.mutable.MutableDouble; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.*; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +import javax.annotation.Nonnull; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; + + +public class SlimUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(SlimUDTF.class); + + private double l1; + private double l2; + private int numIterations; + private int previousItemId; + + private transient DoKMatrix weightMatrix; // item-item weight matrix + private transient DoKMatrix dataMatrix; // item-user matrix + + private PrimitiveObjectInspector itemIOI; + private PrimitiveObjectInspector itemJOI; + private MapObjectInspector riOI; + private MapObjectInspector rjOI; + + private MapObjectInspector knnItemsOI; + private PrimitiveObjectInspector knnItemsKeyOI; + private MapObjectInspector knnItemsValueOI; + 
private PrimitiveObjectInspector knnItemsValueKeyOI; + private PrimitiveObjectInspector knnItemsValueValueOI; + + private PrimitiveObjectInspector riKeyOI; + private PrimitiveObjectInspector riValueOI; + + private PrimitiveObjectInspector rjKeyOI; + private PrimitiveObjectInspector rjValueOI; + + // used to store KNN data into temporary file for iterative training + private NioStatefullSegment fileIO; + private ByteBuffer inputBuf; + + private ConversionState cvState; + private long observedTrainingExamples; + + public SlimUDTF() {} + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + if (numArgs != 5 && numArgs != 6) { + throw new UDFArgumentException( + "_FUNC_ takes arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j, [, constant string options]"); + } + + this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]); + + this.riOI = HiveUtils.asMapOI(argOIs[1]); + this.riKeyOI = HiveUtils.asIntCompatibleOI((this.riOI.getMapKeyObjectInspector())); + this.riValueOI = HiveUtils.asPrimitiveObjectInspector((this.riOI.getMapValueObjectInspector())); + + this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]); + this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector()); + this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector()); + this.knnItemsValueKeyOI = HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector()); + this.knnItemsValueValueOI = HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector()); + + this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]); + + this.rjOI = HiveUtils.asMapOI(argOIs[4]); + this.rjKeyOI = HiveUtils.asIntCompatibleOI((this.rjOI.getMapKeyObjectInspector())); + this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((this.rjOI.getMapValueObjectInspector())); + + processOptions(argOIs); + + List<String> fieldNames = new 
ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); + + fieldNames.add("i"); + fieldNames.add("j"); + fieldNames.add("wij"); + + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + this.observedTrainingExamples = 0L; + this.previousItemId = -2147483648; + + this.dataMatrix = null; + this.weightMatrix = null; + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("l1", "l1coefficient", true, + "Coefficient for l1 regularizer [default: 0.01]"); + opts.addOption("l2", "l2coefficient", true, + "Coefficient for l2 regularizer [default: 0.01]"); + opts.addOption("numIterations", "iteration", true, + "The number of iterations for coordinate descent [default: 40]"); + opts.addOption("disable_cv", "disable_cvtest", false, + "Whether to disable convergence check [default: enabled]"); + opts.addOption("cv_rate", "convergence_rate", true, + "Threshold to determine convergence [default: 0.005]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + double l1 = 0.01d; + double l2 = 0.01d; + int numIterations = 40; + boolean conversionCheck = true; + double cv_rate = 0.005d; + + if (argOIs.length >= 6) { + String rawArgs = HiveUtils.getConstString(argOIs[5]); + cl = parseOptions(rawArgs); + + l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1); + if (l1 < 0.d || l1 > 1.d) { --- End diff -- `if (l1 < 0.d)` is enough; the `l1 > 1.d` upper-bound check is unnecessary.
---