This is an automated email from the ASF dual-hosted git repository.
janardhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 60a2acbf4b [SYSTEMDS-3670] Added early-stopping mechanism to tSNE
(#1990)
60a2acbf4b is described below
commit 60a2acbf4bcfcb1e41a1c12be05504ffafa02720
Author: ramesesz <[email protected]>
AuthorDate: Fri Feb 9 06:05:43 2024 +0100
[SYSTEMDS-3670] Added early-stopping mechanism to tSNE (#1990)
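This patch adds an early-stopping mechanism to the tSNE builtin: gradient descent
now stops once sum(dY^2) drops to or below tol times its first-iteration value, and
the new print_iter parameter sets how often that value is printed when is_verbose
is TRUE.
This patch also improves the builtin dist function by removing the outer product
operator. For 100 function calls on an arbitrary matrix with 4000 rows and 800
cols, the new dist function shortens the runtime from 66.541s to 60.268s.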
---
scripts/builtin/tSNE.dml | 66 +++++++---
.../functions/builtin/part2/BuiltinTSNETest.java | 144 ++++++++++++++++++++-
src/test/scripts/functions/builtin/tSNE.dml | 2 +-
3 files changed, 191 insertions(+), 21 deletions(-)
diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml
index a28a1c1a0a..e4af10fbd8 100644
--- a/scripts/builtin/tSNE.dml
+++ b/scripts/builtin/tSNE.dml
@@ -41,9 +41,12 @@
# lr Learning rate
# momentum Momentum Parameter
# max_iter Number of iterations
+# tol Relative tolerance for early stopping: stop once sum(dY^2) <= tol * (its value in the first iteration)
# seed The seed used for initial values.
# If set to -1 random seeds are selected.
# is_verbose Print debug information
+# print_iter Interval (in iterations) at which the "L1 Norm" value, computed as sum(dY^2), is printed.
+# Ignored if is_verbose = FALSE.
#
-------------------------------------------------------------------------------------------
#
# OUTPUT:
@@ -52,7 +55,8 @@
#
-------------------------------------------------------------------------------------------
m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer
perplexity = 30,
- Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer
seed = -1, Boolean is_verbose = FALSE)
+ Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol
= 1e-5,
+ Integer seed = -1, Boolean is_verbose = FALSE, Integer print_iter = 10)
return(Matrix[Double] Y)
{
d = reduced_dims
@@ -73,8 +77,41 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims =
2, Integer perplexity
if(is_verbose)
print("starting loop....")
- for (itr in 1:max_iter) {
- D = distance_matrix(Y)
+ itr = 1
+
+ # Start first iteration out of loop as benchmark for early stopping
+ D = dist(Y)
+ Z = 1/(D + 1)
+ Z = Z * ZERODIAG
+ Q = Z/sum(Z)
+ W = (P - Q)*Z
+ sumW = rowSums(W)
+ g = Y * sumW - W %*% Y
+ dY = momentum*dY - lr*g
+
+ norm = sum(dY^2)
+ norm_initial = norm
+ norm_target = norm_initial * tol
+
+ if(is_verbose){
+ print("L1 Norm initial : " + norm_initial)
+ print("L1 Norm target : " + norm_target)
+ }
+
+ Y = Y + dY
+ Y = Y - colMeans(Y)
+
+ if (itr%%100 == 0) {
+ C[itr/100,] = sum(P * log(pmax(P, 1e-12) / pmax(Q, 1e-12)))
+ }
+ if (itr == 100) {
+ P = P/4
+ }
+ itr = itr + 1
+ # End of first iteration
+
+ while (itr <= max_iter & norm > norm_target) {
+ D = dist(Y)
Z = 1/(D + 1)
Z = Z * ZERODIAG
Q = Z/sum(Z)
@@ -82,6 +119,13 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims =
2, Integer perplexity
sumW = rowSums(W)
g = Y * sumW - W %*% Y
dY = momentum*dY - lr*g
+
+ norm = sum(dY^2)
+ if(is_verbose & itr %% print_iter == 0){
+ print("Iteration: " + itr)
+ print("L1 Norm: " + norm)
+ }
+
Y = Y + dY
Y = Y - colMeans(Y)
@@ -91,20 +135,10 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims =
2, Integer perplexity
if (itr == 100) {
P = P/4
}
+ itr = itr + 1
}
}
-distance_matrix = function(matrix[double] X)
- return (matrix[double] out)
-{
- # TODO consolidate with dist() builtin, but with
- # better way of obtaining the diag from
- n = nrow(X)
- s = rowSums(X * X)
- out = - 2*X %*% t(X) + s + t(s)
-}
-
-
x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE)
return(matrix[double] P)
{
@@ -115,7 +149,7 @@ return(matrix[double] P)
n = nrow(X)
if(is_verbose)
print(n)
- D = distance_matrix(X)
+ D = dist(X)
P = matrix(0, rows=n, cols=n)
beta = matrix(1, rows=n, cols=1)
@@ -129,7 +163,7 @@ return(matrix[double] P)
while (mean(abs(Hdiff)) > tol & itr < 50) {
P = exp(-D * beta)
P = P * ZERODIAG
- sum_Pi = rowSums(P)
+ sum_Pi = rowSums(P) + 1e-12
W = rowSums(P * D)
Ws = W/sum_Pi
H = log(sum_Pi) + beta * Ws
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
index 41d2d4606f..44d06ffaf9 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.junit.Test;
+import static org.junit.Assert.assertTrue;
import java.io.IOException;
@@ -41,12 +42,12 @@ public class BuiltinTSNETest extends AutomatedTestBase
@Test
public void testTSNECP() throws IOException {
runTSNETest(2, 30, 300.,
- 0.9, 1000, 42, "FALSE", ExecType.CP);
+ 0.9, 1000, 1e-5d, 42, "FALSE", 10, ExecType.CP);
}
@SuppressWarnings("unused")
- private void runTSNETest(Integer reduced_dims, Integer perplexity,
Double lr,
- Double momentum, Integer max_iter, Integer seed, String
is_verbose, ExecType instType)
+ private void runTSNETest(int reduced_dims, int perplexity, double lr,
+ double momentum, int max_iter, double tol, int seed, String
is_verbose, Integer print_iter, ExecType instType)
throws IOException
{
ExecMode platformOld = setExecMode(instType);
@@ -64,8 +65,11 @@ public class BuiltinTSNETest extends AutomatedTestBase
"lr=" + lr,
"momentum=" + momentum,
"max_iter=" + max_iter,
+ "tol=" + tol,
"seed=" + seed,
- "is_verbose=" + is_verbose};
+ "is_verbose=" + is_verbose,
+ "print_iter=" + print_iter
+ };
// The Input values are calculated using the following
R script:
// TODO create via dml operations, avoid inlining data
@@ -403,4 +407,136 @@ public class BuiltinTSNETest extends AutomatedTestBase
rtplatform = platformOld;
}
}
+
+
+ @Test
+ public void testTSNEEarlyStopping() throws IOException {
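+ // Loose tolerance (tol = 1e-1) so early stopping is expected to trigger before max_iter = 1000; is_verbose = TRUE so the norm values are printed.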
+ runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1,
"TRUE", 10, ExecType.CP);
+ }
+
+ @SuppressWarnings("unused")
+ private void runTSNEEarlyStoppingTest(
+ Integer reduced_dims,
+ Integer perplexity,
+ Double lr,
+ Double momentum,
+ Integer max_iter,
+ Double tol,
+ Integer seed,
+ String is_verbose,
+ Integer print_iter,
+ ExecType instType) throws IOException {
+
+ ExecMode platformOld = setExecMode(instType);
+ try
+ {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{
+ "-nvargs", "X=" + input("X"), "Y=" +
output("Y"),
+ "reduced_dims=" + reduced_dims,
+ "perplexity=" + perplexity,
+ "lr=" + lr,
+ "momentum=" + momentum,
+ "max_iter=" + max_iter,
+ "tol=" + tol,
+ "seed=" + seed,
+ "is_verbose=" + is_verbose,
+ "print_iter=" + print_iter
+ };
+
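+ // The input values were calculated using the following DML script:
+ // X = rand(rows=50, cols=2, min=0, max=5, seed=1) -- NOTE(review): the values below include negatives (down to about -4.92), so min was presumably -5; confirm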
+
+ // Input
+ double[][] X = {
+ {-0.45700987356506406, 2.834752454661148},
+ {-1.3945444464226533, -3.8794723634582597},
+ {-1.338576451510809, 1.160918547857504},
+ {2.921891699889728, -1.32749074959577},
+ {1.3001464754324763, -1.1353208514333533},
+ {0.2866401390950628, 3.359214248871961},
+ {2.6740553056629217, -1.2030274345674852},
+ {1.7240446900374895, 3.4430052477647557},
+ {-0.3435254305219493, 4.205393963204703},
+ {-2.873899183923896, 1.098272406118296},
+ {4.890217056606042, 1.5814575251762104},
+ {-4.920042511612875, 4.579455675519821},
+ {1.439881754507784, -4.090781835042895},
+ {2.32372435941579, 4.823050596338641},
+ {-0.9864739586714544, -1.6990853495458147},
+ {4.605792626050157, 2.411639339263437},
+ {4.979120527950069, 1.7181757158820465},
+ {-4.423608438974177, 0.44712526968937283},
+ {3.4109472479162317, -3.497269670333382},
+ {-1.9938801849366037, -1.1880069697833906},
+ {3.223381639747396, 3.7784510177449793},
+ {2.10470587687118, 0.5415570090498525},
+ {2.084254693325721, 1.4369473809787037},
+ {-0.9957311983302795, 1.586795215124286},
+ {-3.7527381124013894, 4.3818220996816475},
+ {3.5748622228245193, 1.116518048277384},
+ {-2.297351475873446, -2.0179124546489047},
+ {-0.3438938003649259, 0.689249021371154},
+ {-0.8823286368673617, 1.2731356499886672},
+ {2.517220722615252, -2.8806532181877254},
+ {3.923092638022041, 4.34404320783608},
+ {-2.1012040153953, -4.33147229525127},
+ {3.5992422607685715, 2.5628828792092904},
+ {4.3431460760781775, -2.6869010463029754},
+ {-3.27506631006849, -1.1828954200032116},
+ {-4.3138906717810475, -3.7311556655569875},
+ {4.674799759142193, 3.783941497422669},
+ {3.561677127461424, 1.699651989293141},
+ {-3.0146338910401838, 3.3961817590254952},
+ {-4.438156472502506, 0.5926080631113129},
+ {-4.6425401564313615, 2.131545102584216},
+ {3.2975878235392244, -2.8485717910480988},
+ {-0.9776972765619627, 0.5292861827847535},
+ {-3.9770843662935915, -2.258269867772177},
+ {-4.22908475002643, -4.574457493889454},
+ {-0.28759876443714827, -0.5841999820607002},
+ {2.33121643992511, 1.7993339510854582},
+ {-1.476311475439723, 4.3511414590258894},
+ {4.974472387105775, -4.165990440844669},
+ {-4.570078514420281, 2.156235882831523}
+ };
+
+ writeInputMatrixWithMTD("X", X, true);
+
+ // Capture console output
+ setOutputBuffering(true);
+ String out = runTest(true, false, null, -1).toString();
+
+ // Parse and check L1 norm values
+ String[] lines = out.split(System.lineSeparator());
+ double prevL1Norm = Double.POSITIVE_INFINITY;
+ boolean decreasing = true;
+ int notDecreasingCount = 0; // Counter to track
consecutive non-decreasing values
+ for (String line : lines) {
+ if (line.startsWith("L1 Norm:")) {
+ double l1Norm =
Double.parseDouble(line.substring(9).trim());
+ if (l1Norm >= prevL1Norm) {
+ notDecreasingCount++;
+ if (notDecreasingCount >= 3) {
+ decreasing = false;
+ break; // Exit the loop
once we've seen 3 consecutive non-decreasing values
+ }
+ } else {
+ notDecreasingCount = 0; //
Reset the counter if the current value is decreasing
+ }
+ prevL1Norm = l1Norm;
+ }
+ }
+
+ assertTrue("L1 norm should decrease each time it is printed
out", decreasing);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+
+ }
}
diff --git a/src/test/scripts/functions/builtin/tSNE.dml
b/src/test/scripts/functions/builtin/tSNE.dml
index 8310f75a39..88e7c03910 100644
--- a/src/test/scripts/functions/builtin/tSNE.dml
+++ b/src/test/scripts/functions/builtin/tSNE.dml
@@ -20,5 +20,5 @@
#-------------------------------------------------------------
X = read($X);
-Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed,
$is_verbose)
+Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $tol,
$seed, $is_verbose, $print_iter)
write(Y, $Y)