Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/167#discussion_r226243032 --- Diff: core/src/main/java/hivemall/mf/CofactorModel.java --- @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.mf; + +import hivemall.fm.Feature; +import hivemall.utils.math.MathUtils; +import hivemall.utils.math.MatrixUtils; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.commons.math3.linear.SingularValueDecomposition; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.*; + +public class CofactorModel { + + public enum RankInitScheme { + random /* default */, gaussian; + + @Nonnegative + protected float maxInitValue; + @Nonnegative + protected double initStdDev; + + @Nonnull + public static CofactorModel.RankInitScheme resolve(@Nullable String opt) { + if (opt == null) { + return random; + } else if ("gaussian".equalsIgnoreCase(opt)) { + return gaussian; + } else if ("random".equalsIgnoreCase(opt)) { + return random; + } + return random; + } + + public void setMaxInitValue(float maxInitValue) { + this.maxInitValue = maxInitValue; + } + + public void setInitStdDev(double initStdDev) { + this.initStdDev = initStdDev; + } + + } + + private static final int EXPECTED_SIZE = 136861; + @Nonnegative + protected final int factor; + + // rank matrix initialization + protected final RankInitScheme initScheme; + + @Nonnull + private double globalBias; + + // storing trainable latent factors and weights + private Map<String, RealVector> theta; + private Map<String, RealVector> beta; + private Map<String, Double> betaBias; + private Map<String, RealVector> gamma; + private Map<String, Double> gammaBias; + + // precomputed identity matrix + private RealMatrix identity; + + protected final Random[] randU, randI; + + // hyperparameters + private final float c0, c1; + private final float lambdaTheta, lambdaBeta, lambdaGamma; + + public CofactorModel(@Nonnegative int factor, @Nonnull RankInitScheme 
initScheme, + @Nonnull float c0, @Nonnull float c1, float lambdaTheta, + float lambdaBeta, float lambdaGamma) { + + // rank init scheme is gaussian + // https://github.com/dawenl/cofactor/blob/master/src/cofacto.py#L98 + this.factor = factor; + this.initScheme = initScheme; + this.globalBias = 0.d; + this.lambdaTheta = lambdaTheta; + this.lambdaBeta = lambdaBeta; + this.lambdaGamma = lambdaGamma; + + this.theta = new HashMap<>(); + this.beta = new HashMap<>(); + this.betaBias = new HashMap<>(); + this.gamma = new HashMap<>(); + this.gammaBias = new HashMap<>(); + + this.randU = newRandoms(factor, 31L); + this.randI = newRandoms(factor, 41L); + + checkHyperparameterC(c0); + checkHyperparameterC(c1); + this.c0 = c0; + this.c1 = c1; + + } + + private void initFactorVector(String key, Map<String, RealVector> weights) { + if (weights.containsKey(key)) { + return; + } + RealVector v = new ArrayRealVector(factor); + switch (initScheme) { + case random: + uniformFill(v, randI[0], initScheme.maxInitValue); + break; + case gaussian: + gaussianFill(v, randI, initScheme.initStdDev); + break; + default: + throw new IllegalStateException( + "Unsupported rank initialization scheme: " + initScheme); + + } + weights.put(key, v); + } + + private static RealVector getFactorVector(String key, Map<String, RealVector> weights) { + return weights.get(key); + } + + private static void setFactorVector(String key, Map<String, RealVector> weights, RealVector factorVector) { + assert weights.containsKey(key); + weights.put(key, factorVector); + } + + private static double getBias(String key, Map<String, Double> biases) { + if (!biases.containsKey(key)) { --- End diff -- Three hash lookups in the worst case... ``` final Double v = biases.get(key); if(v == null) { return 0.d; } return v.doubleValue(); ``` Or, ```java private static double getBias(String key, Object2DoubleMap<String> biases) { // biases.defaultReturnValue(0.f); -- set in initialization return biases.get(key); } ```
---