Github user jaxony commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/167#discussion_r226817004
  
    --- Diff: core/src/main/java/hivemall/mf/CofactorizationUDTF.java ---
    @@ -0,0 +1,584 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.mf;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.annotations.VisibleForTesting;
    +import hivemall.common.ConversionState;
    +import hivemall.fm.Feature;
    +import hivemall.fm.StringFeature;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.io.FileUtils;
    +import hivemall.utils.io.NIOUtils;
    +import hivemall.utils.io.NioStatefulSegment;
    +import hivemall.utils.lang.NumberUtils;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.SizeOf;
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.*;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
    +import org.apache.hadoop.mapred.Counters;
    +import org.apache.hadoop.mapred.Reporter;
    +
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import java.io.File;
    +import java.io.IOException;
    +import java.nio.ByteBuffer;
    +import java.util.ArrayList;
    +import java.util.List;
    +
    +import static hivemall.utils.lang.Primitives.FALSE_BYTE;
    +import static hivemall.utils.lang.Primitives.TRUE_BYTE;
    +
    +public class CofactorizationUDTF extends UDTFWithOptions {
    +    private static final Log LOG = LogFactory.getLog(CofactorizationUDTF.class);
    +
    +    // Option variables
    +    // The number of latent factors
    +    private int factor;
    +    // The scaling hyperparameter for zero entries in the rank matrix
    +    private float scale_zero;
    +    // The scaling hyperparameter for non-zero entries in the rank matrix
    +    private float scale_nonzero;
    +    // The preferred size of the mini-batch for training
    +    private int batchSize;
    +    // The initial mean rating
    +    private float globalBias;
    +    // Whether to update (and return) the mean rating
    +    private boolean updateGlobalBias;
    +    // The number of iterations
    +    private int maxIters;
    +    // Whether to use bias clause
    +    private boolean useBiasClause;
    +    // Whether to use normalization
    +    private boolean useL2Norm;
    +    // regularization hyperparameters
    +    private float lambdaTheta;
    +    private float lambdaBeta;
    +    private float lambdaGamma;
    +
    +    // Initialization strategy of rank matrix
    +    private CofactorModel.RankInitScheme rankInit;
    +
    +    // Model itself
    +    private CofactorModel model;
    +    private int numItems;
    +
    +    // Variables managing the learning status
    +
    +    // The number of processed training examples
    +    private long count;
    +
    +    private ConversionState cvState;
    +    private ConversionState validationState;
    +
    +    // Input OIs and Context
    +    private StringObjectInspector contextOI;
    +    @VisibleForTesting
    +    protected ListObjectInspector featuresOI;
    +    private BooleanObjectInspector isItemOI;
    +    @VisibleForTesting
    +    protected ListObjectInspector sppmiOI;
    +
    +    // Used for iterations
    +    @VisibleForTesting
    +    protected NioStatefulSegment fileIO;
    +    private ByteBuffer inputBuf;
    +    private long lastWritePos;
    +
    +    private String contextProbe;
    +    private Feature[] featuresProbe;
    +    private Feature[] sppmiProbe;
    +    private boolean isItemProbe;
    +    private long numValidations;
    +    private long numTraining;
    +
    +    static class MiniBatch {
    +        private List<TrainingSample> users;
    +        private List<TrainingSample> items;
    +
    +        protected MiniBatch() {
    +            users = new ArrayList<>();
    +            items = new ArrayList<>();
    +        }
    +
    +        protected void add(TrainingSample sample) {
    +            if (sample.isItem()) {
    +                items.add(sample);
    +            } else {
    +                users.add(sample);
    +            }
    +        }
    +
    +        protected void clear() {
    +            users.clear();
    +            items.clear();
    +        }
    +
    +        protected int size() {
    +            return items.size() + users.size();
    +        }
    +
    +        protected List<TrainingSample> getItems() {
    +            return items;
    +        }
    +
    +        protected List<TrainingSample> getUsers() {
    +            return users;
    +        }
    +    }
    +
    +    static class TrainingSample {
    +        protected String context;
    +        protected Feature[] features;
    +        protected Feature[] sppmi;
    +
    +        protected TrainingSample(String context, Feature[] features, 
Feature[] sppmi) {
    +            this.context = context;
    +            this.features = features;
    +            this.sppmi = sppmi;
    +        }
    +
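    +        // only item samples carry an SPPMI vector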
    +        protected boolean isItem() {
    +            return sppmi != null;
    +        }
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("k", "factor", true, "The number of latent factor 
[default: 10] "
    +                + " Note this is alias for `factors` option.");
    +        opts.addOption("f", "factors", true, "The number of latent factor 
[default: 10]");
    +        opts.addOption("r", "lambda", true, "The regularization factor 
[default: 0.03]");
    +        opts.addOption("c0", "scale_zero", true,
    +                "The scaling hyperparameter for zero entries in the rank 
matrix [default: 0.1]");
    +        opts.addOption("c1", "scale_nonzero", true,
    +                "The scaling hyperparameter for non-zero entries in the 
rank matrix [default: 1.0]");
    +        opts.addOption("b", "batch_size", true, "The miniBatch size for 
training [default: 1024]");
    +        opts.addOption("n", "num_items", false, "Number of items");
    +        opts.addOption("gb", "global_bias", true, "The global bias 
[default: 0.0]");
    +        opts.addOption("update_gb", "update_gb", false,
    +                "Whether update (and return) the global bias or not");
    +        opts.addOption("rankinit", true,
    +                "Initialization strategy of rank matrix [random, gaussian] 
(default: gaussian)");
    +        opts.addOption("maxval", "max_init_value", true,
    +                "The maximum initial value in the rank matrix [default: 
1.0]");
    +        opts.addOption("min_init_stddev", true,
    +                "The minimum standard deviation of initial rank matrix 
[default: 0.01]");
    +        opts.addOption("iters", "iterations", true, "The number of 
iterations [default: 1]");
    +        opts.addOption("iter", true,
    +                "The number of iterations [default: 1] Alias for 
`-iterations`");
    +        opts.addOption("disable_cv", "disable_cvtest", false,
    +                "Whether to disable convergence check [default: enabled]");
    +        opts.addOption("cv_rate", "convergence_rate", true,
    +                "Threshold to determine convergence [default: 0.005]");
    +        opts.addOption("disable_bias", "no_bias", false, "Turn off bias 
clause");
    +        // feature representation
    +        opts.addOption("int_feature", "feature_as_integer", false,
    +                "Parse a feature as integer [default: OFF]");
    +        // normalization
    +        opts.addOption("disable_norm", "disable_l2norm", false, "Disable 
instance-wise L2 normalization");
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        CommandLine cl = null;
    +        String rankInitOpt = "gaussian";
    +        float maxInitValue = 1.f;
    +        double initStdDev = 0.1d;
    +        boolean conversionCheck = true;
    +        double convergenceRate = 0.005d;
    +
    +        if (argOIs.length >= 5) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[4]);
    +            cl = parseOptions(rawArgs);
    +            if (cl.hasOption("factors")) {
    +                this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
    +            } else {
    +                this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
    +            }
    +            this.lambdaTheta = Primitives.parseFloat(cl.getOptionValue("lambda_theta"), 1e-5f);
    +            this.lambdaBeta = Primitives.parseFloat(cl.getOptionValue("lambda_beta"), 1e-5f);
    +            this.lambdaGamma = Primitives.parseFloat(cl.getOptionValue("lambda_gamma"), 1e+0f);
    +            this.scale_zero = Primitives.parseFloat(cl.getOptionValue("scale_zero"), 0.1f);
    +            this.scale_nonzero = Primitives.parseFloat(cl.getOptionValue("scale_nonzero"), 1.0f);
    +            this.batchSize = Primitives.parseInt(cl.getOptionValue("batch_size"), 1024);
    +            if (cl.hasOption("num_items")) {
    +                this.numItems = Primitives.parseInt(cl.getOptionValue("num_items"), 1024);
    +            } else {
    +                throw new UDFArgumentException("-num_items must be specified");
    +            }
    +            this.globalBias = Primitives.parseFloat(cl.getOptionValue("gb"), 0.f);
    +            this.updateGlobalBias = cl.hasOption("update_gb");
    +            rankInitOpt = cl.getOptionValue("rankinit");
    +            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f);
    +            initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.01d);
    +            if (cl.hasOption("iter")) {
    +                this.maxIters = 
Primitives.parseInt(cl.getOptionValue("iter"), 1);
    +            } else {
    +                this.maxIters = 
Primitives.parseInt(cl.getOptionValue("maxIters"), 1);
    +            }
    +            if (maxIters < 1) {
    +                throw new UDFArgumentException(
    +                        "'-maxIters' must be greater than or equal to 1: " 
+ maxIters);
    +            }
    +            conversionCheck = !cl.hasOption("disable_cvtest");
    +            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
    +            boolean noBias = cl.hasOption("no_bias");
    +            this.useBiasClause = !noBias;
    +            if (noBias && updateGlobalBias) {
    +                throw new UDFArgumentException(
    +                        "Cannot set both `update_gb` and `no_bias` 
option");
    +            }
    +            this.useL2Norm = !cl.hasOption("disable_l2norm");
    +        }
    +        this.rankInit = CofactorModel.RankInitScheme.resolve(rankInitOpt);
    +        rankInit.setMaxInitValue(maxInitValue);
    +        initStdDev = Math.max(initStdDev, 1.0d / factor);
    +        rankInit.setInitStdDev(initStdDev);
    +        this.cvState = new ConversionState(conversionCheck, convergenceRate);
    +        return cl;
    +    }
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (argOIs.length < 3) {
    +            throw new UDFArgumentException(
    +                    "_FUNC_ takes 5 arguments: string context, 
array<string> features, boolean is_item, array<string> sppmi [, CONSTANT STRING 
options]");
    +        }
    +        this.contextOI = HiveUtils.asStringOI(argOIs[0]);
    +        this.featuresOI = HiveUtils.asListOI(argOIs[1]);
    +        HiveUtils.validateFeatureOI(featuresOI.getListElementObjectInspector());
    +        this.isItemOI = HiveUtils.asBooleanOI(argOIs[2]);
    +        this.sppmiOI = HiveUtils.asListOI(argOIs[3]);
    +        HiveUtils.validateFeatureOI(sppmiOI.getListElementObjectInspector());
    +
    +        processOptions(argOIs);
    +
    +        this.model = new CofactorModel(factor, rankInit, scale_zero, scale_nonzero, lambdaTheta, lambdaBeta, lambdaGamma);
    +        this.count = 0L;
    +        this.lastWritePos = 0L;
    +
    +        List<String> fieldNames = new ArrayList<String>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
    +        fieldNames.add("idx");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("Pu");
    +        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
    +                PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
    +        fieldNames.add("Qi");
    +        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
    +                PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (args.length != 4) {
    +            throw new HiveException("should have 4 args, but have " + 
args.length);
    +        }
    +
    +        String context = contextOI.getPrimitiveJavaObject(args[0]);
    +        Feature[] features = parseFeatures(args[1], featuresOI, featuresProbe);
    +        if (features == null) {
    +            throw new HiveException("features must not be null");
    +        }
    +
    +        boolean isItem = isItemOI.get(args[2]);
    +        Feature[] sppmi = null;
    +        if (isItem) {
    +            sppmi = parseFeatures(args[3], sppmiOI, sppmiProbe);
    +        }
    +
    +        model.recordContext(context, isItem);
    +
    +        this.contextProbe = context;
    +        this.featuresProbe = features;
    +        this.isItemProbe = isItem;
    +        this.sppmiProbe = sppmi;
    +
    +        recordTrain(context, features, sppmi);
    +    }
    +
    +    @Nullable
    +    @VisibleForTesting
    +    protected static Feature[] parseFeatures(@Nullable final Object arg, ListObjectInspector listOI, Feature[] probe) throws HiveException {
    +        if (arg == null) {
    +            return null;
    +        }
    +        Feature[] rawFeatures = Feature.parseFeatures(arg, listOI, probe, false);
    +        return createNnzFeatureArray(rawFeatures);
    +    }
    +
    +    @VisibleForTesting
    +    protected static Feature[] createNnzFeatureArray(@Nonnull Feature[] x) {
    +        int nnz = countNnzFeatures(x);
    +        Feature[] nnzFeatures = new Feature[nnz];
    +        int i = 0;
    +        for (Feature f : x) {
    +            if (f.getValue() != 0.d) {
    +                nnzFeatures[i++] = f;
    +            }
    +        }
    +        return nnzFeatures;
    +    }
    +
    +    private static int countNnzFeatures(@Nonnull Feature[] x) {
    +        int nnz = 0;
    +        for (Feature f : x) {
    +            if (f.getValue() != 0.d) {
    +                nnz++;
    +            }
    +        }
    +        return nnz;
    +    }
    +
    +    private Double trainMiniBatch(MiniBatch miniBatch) throws HiveException {
    +        model.updateWithUsers(miniBatch.getUsers());
    +        model.updateWithItems(miniBatch.getItems());
    +        return model.calculateLoss(miniBatch.getUsers(), miniBatch.getItems());
    +    }
    +
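    +    // Buffers each training record and spills it to a temporary file so the
    +    // examples can be replayed over multiple training iterations in close()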
    +    private void recordTrain(final String context, final Feature[] features, final Feature[] sppmi)
    +            throws HiveException {
    +        numTraining++;
    +        ByteBuffer inputBuf = this.inputBuf;
    +        NioStatefulSegment dst = this.fileIO;
    +        if (inputBuf == null) {
    +            final File file = createTempFile();
    +            this.inputBuf = inputBuf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MiB
    +            this.fileIO = dst = new NioStatefulSegment(file, false);
    +        }
    +
    +        writeRecordToBuffer(inputBuf, dst, context, features, sppmi);
    +    }
    +
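    +    // Serialized record layout: [int recordBytes][context string][features][is-item byte][sppmi features, items only]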
    +    private static void writeRecordToBuffer(@Nonnull final ByteBuffer inputBuf, @Nonnull final NioStatefulSegment dst, @Nonnull final String context,
    +                                            @Nonnull final Feature[] features, @Nullable final Feature[] sppmi) throws HiveException {
    +        int recordBytes = calculateRecordBytes(context, features, sppmi);
    +        int requiredBytes = SizeOf.INT + recordBytes;
    +        int remain = inputBuf.remaining();
    +
    +        if (remain < requiredBytes) {
    +            writeBuffer(inputBuf, dst);
    +        }
    +
    +        inputBuf.putInt(recordBytes);
    +        NIOUtils.putString(context, inputBuf);
    +        writeFeaturesToBuffer(features, inputBuf);
    +        if (sppmi != null) {
    +            inputBuf.put(TRUE_BYTE);
    +            writeFeaturesToBuffer(sppmi, inputBuf);
    +        } else {
    +            inputBuf.put(FALSE_BYTE);
    +        }
    +    }
    +
    +    private static int calculateRecordBytes(String context, Feature[] features, Feature[] sppmi) {
    +        int contextBytes = SizeOf.INT + SizeOf.CHAR * context.length();
    +        int featuresBytes = SizeOf.INT + Feature.requiredBytes(features);
    +        int isItemBytes = SizeOf.BYTE;
    +        int sppmiBytes = sppmi != null ? SizeOf.INT + Feature.requiredBytes(sppmi) : 0;
    +        return contextBytes + featuresBytes + isItemBytes + sppmiBytes;
    +    }
    +
    +    private static File createTempFile() throws UDFArgumentException {
    +        final File file;
    +        try {
    +            file = File.createTempFile("hivemall_cofactor", ".sgmt");
    +            file.deleteOnExit();
    +            if (!file.canWrite()) {
    +                throw new UDFArgumentException(
    +                        "Cannot write a temporary file: " + 
file.getAbsolutePath());
    +            }
    +            LOG.info("Record training examples to a file: " + 
file.getAbsolutePath());
    +        } catch (IOException ioe) {
    +            throw new UDFArgumentException(ioe);
    +        } catch (Throwable e) {
    +            throw new UDFArgumentException(e);
    +        }
    +        return file;
    +    }
    +
    +    private static void writeFeaturesToBuffer(Feature[] features, ByteBuffer buffer) {
    +        buffer.putInt(features.length);
    +        for (Feature f : features) {
    +            f.writeTo(buffer);
    +        }
    +    }
    +
    +    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst)
    +            throws HiveException {
    +        srcBuf.flip();
    +        try {
    +            dst.write(srcBuf);
    +        } catch (IOException e) {
    +            throw new HiveException("Exception causes while writing a 
buffer to file", e);
    +        }
    +        srcBuf.clear();
    +    }
    +
    +    @Override
    +    public void close() throws HiveException {
    +        try {
    +            boolean lossIncreasedLastIter = false;
    +
    +            final Reporter reporter = getReporter();
    +            final Counters.Counter iterCounter = (reporter == null) ? null
    +                    : reporter.getCounter("hivemall.mf.Cofactor$Counter", "iteration");
    +
    +            prepareForRead();
    +
    +            if (LOG.isInfoEnabled()) {
    +                File tmpFile = fileIO.getFile();
    +                LOG.info("Wrote " + numTraining
    +                        + " records to a temporary file for iterative 
training: "
    +                        + tmpFile.getAbsolutePath() + " (" + 
FileUtils.prettyFileSize(tmpFile)
    +                        + ")");
    +            }
    +
    +            for (int iteration = 0; iteration < maxIters; iteration++) {
    +                // train the model on a full batch (i.e., all the data) using mini-batch updates
    +                cvState.next();
    +                reportProgress(reporter);
    +                setCounterValue(iterCounter, iteration);
    +                runTrainingIteration();
    +
    +                LOG.info("Performed " + cvState.getCurrentIteration() + " 
iterations of "
    +                        + NumberUtils.formatNumber(maxIters));
    +//                        + " training examples on a secondary storage 
(thus "
    +//                        + NumberUtils.formatNumber(_t) + " training 
updates in total), used "
    +//                        + _numValidations + " validation examples");
    +            }
    +        } finally {
    +            // delete the temporary file and release resources
    +            try {
    +                fileIO.close(true);
    +            } catch (IOException e) {
    +                throw new HiveException(
    +                        "Failed to close a file: " + 
fileIO.getFile().getAbsolutePath(), e);
    +            }
    +            this.inputBuf = null;
    +            this.fileIO = null;
    +        }
    +    }
    +
    +    @VisibleForTesting
    +    protected void prepareForRead() throws HiveException {
    +        // write training examples remaining in the buffer to the temporary file
    +        if (inputBuf.position() > 0) {
    +            writeBuffer(inputBuf, fileIO);
    +        }
    +        try {
    +            fileIO.flush();
    +        } catch (IOException e) {
    +            throw new HiveException(
    +                    "Failed to flush a file: " + 
fileIO.getFile().getAbsolutePath(), e);
    +        }
    +        fileIO.resetPosition();
    +    }
    +
    +    private void runTrainingIteration() throws HiveException {
    +        fileIO.resetPosition();
    +        MiniBatch miniBatch = new MiniBatch();
    +        // read minibatch from disk into memory
    +        while (readMiniBatchFromFile(miniBatch)) {
    +            Double trainLoss = trainMiniBatch(miniBatch);
    +            if (trainLoss != null) {
    +                cvState.incrLoss(trainLoss);
    +            }
    +            miniBatch.clear();
    +        }
    +    }
    +
    +    @Nonnull
    +    private static Feature instantiateFeature(@Nonnull final ByteBuffer input) {
    +        return new StringFeature(input);
    +    }
    +
    +    @VisibleForTesting
    +    protected boolean readMiniBatchFromFile(MiniBatch miniBatch) throws HiveException {
    +        inputBuf.clear();
    --- End diff --
    
    Fix this bug: remove this line. `readMiniBatchFromFile` is called repeatedly within each training iteration, so calling `inputBuf.clear()` at the top of every call resets the buffer's position and limit and silently discards any buffered records that have not yet been consumed.
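    
    A minimal sketch of the read loop after the fix, for reference. The rest of the method body is not shown in this diff, so everything below is inferred from `writeRecordToBuffer` above rather than taken from the PR: `readFeatures` is a hypothetical helper mirroring `writeFeaturesToBuffer`, and `NIOUtils.getString` and `NioStatefulSegment#read(ByteBuffer)` are assumed to be the counterparts of the `putString` and `write` calls on the write path:
    
    ```java
    // Illustrative sketch only, not the PR's actual implementation.
    protected boolean readMiniBatchFromFile(MiniBatch miniBatch) throws HiveException {
        try {
            fileIO.read(inputBuf); // refill the buffer from the temporary file segment
        } catch (IOException e) {
            throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
        inputBuf.flip(); // switch the buffer to draining (read) mode
        while (inputBuf.remaining() >= SizeOf.INT && miniBatch.size() < batchSize) {
            final int recordBytes = inputBuf.getInt();
            if (inputBuf.remaining() < recordBytes) { // partial record; rewind and refill on the next call
                inputBuf.position(inputBuf.position() - SizeOf.INT);
                break;
            }
            final String context = NIOUtils.getString(inputBuf);
            final Feature[] features = readFeatures(inputBuf);
            final Feature[] sppmi = (inputBuf.get() == TRUE_BYTE) ? readFeatures(inputBuf) : null;
            miniBatch.add(new TrainingSample(context, features, sppmi));
        }
        inputBuf.compact(); // keep any unread bytes for the next call instead of clear()
        return miniBatch.size() > 0;
    }
    
    // Hypothetical helper mirroring writeFeaturesToBuffer
    private static Feature[] readFeatures(@Nonnull final ByteBuffer buf) {
        final Feature[] features = new Feature[buf.getInt()];
        for (int i = 0; i < features.length; i++) {
            features[i] = instantiateFeature(buf);
        }
        return features;
    }
    ```
    
    The key change is replacing the unconditional `clear()` with `flip()`/`compact()` bookkeeping, so records that straddle a refill boundary survive to the next call.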

