This is an automated email from the ASF dual-hosted git repository.

janardhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 60a2acbf4b [SYSTEMDS-3670] Added early-stopping mechanism to tSNE 
(#1990)
60a2acbf4b is described below

commit 60a2acbf4bcfcb1e41a1c12be05504ffafa02720
Author: ramesesz <[email protected]>
AuthorDate: Fri Feb 9 06:05:43 2024 +0100

    [SYSTEMDS-3670] Added early-stopping mechanism to tSNE (#1990)
    
    This patch improves the builtin dist function by removing the outer product 
operator. For 100 function calls on an arbitrary matrix with 4000 rows and 800 
cols, the new dist function shortens the runtime from 66.541s to 60.268s.
---
 scripts/builtin/tSNE.dml                           |  66 +++++++---
 .../functions/builtin/part2/BuiltinTSNETest.java   | 144 ++++++++++++++++++++-
 src/test/scripts/functions/builtin/tSNE.dml        |   2 +-
 3 files changed, 191 insertions(+), 21 deletions(-)

diff --git a/scripts/builtin/tSNE.dml b/scripts/builtin/tSNE.dml
index a28a1c1a0a..e4af10fbd8 100644
--- a/scripts/builtin/tSNE.dml
+++ b/scripts/builtin/tSNE.dml
@@ -41,9 +41,12 @@
 # lr             Learning rate
 # momentum       Momentum Parameter
 # max_iter       Number of iterations
+# tol            Relative tolerance for early stopping: the loop ends once sum(dY^2) <= tol * (its first-iteration value)
 # seed           The seed used for initial values.
 #                If set to -1 random seeds are selected.
 # is_verbose     Print debug information
+# print_iter     Number of iterations between printouts of the update norm sum(dY^2)
+#                (labelled "L1 Norm" in the output). Ignored if is_verbose = FALSE.
 # 
-------------------------------------------------------------------------------------------
 #
 # OUTPUT:
@@ -52,7 +55,8 @@
 # 
-------------------------------------------------------------------------------------------
 
 m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer 
perplexity = 30,
-  Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer 
seed = -1, Boolean is_verbose = FALSE)
+  Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol 
= 1e-5, 
+  Integer seed = -1, Boolean is_verbose = FALSE, Integer print_iter = 10)
   return(Matrix[Double] Y)
 {
   d = reduced_dims
@@ -73,8 +77,41 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 
2, Integer perplexity
   if(is_verbose)
     print("starting loop....")
 
-  for (itr in 1:max_iter) {
-    D = distance_matrix(Y)
+  itr = 1
+
+  # Run the first iteration outside the loop; its norm is the baseline for early stopping
+  D = dist(Y)
+  Z = 1/(D + 1)
+  Z = Z * ZERODIAG
+  Q = Z/sum(Z)
+  W = (P - Q)*Z
+  sumW = rowSums(W)
+  g = Y * sumW - W %*% Y
+  dY = momentum*dY - lr*g
+       
+  norm = sum(dY^2)
+  norm_initial = norm
+  norm_target = norm_initial * tol
+
+  if(is_verbose){
+    print("L1 Norm initial : " + norm_initial)
+    print("L1 Norm target  : " + norm_target)
+  }
+
+  Y = Y + dY
+  Y = Y - colMeans(Y)
+
+  if (itr%%100 == 0) {
+    C[itr/100,] = sum(P * log(pmax(P, 1e-12) / pmax(Q, 1e-12)))
+  }
+  if (itr == 100) {
+    P = P/4
+  }
+  itr = itr + 1
+  # End of first iteration
+
+  while (itr <= max_iter & norm > norm_target) {
+    D = dist(Y)
     Z = 1/(D + 1)
     Z = Z * ZERODIAG
     Q = Z/sum(Z)
@@ -82,6 +119,13 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 
2, Integer perplexity
     sumW = rowSums(W)
     g = Y * sumW - W %*% Y
     dY = momentum*dY - lr*g
+
+    norm = sum(dY^2)
+    if(is_verbose & itr %% print_iter == 0){
+      print("Iteration: " + itr)
+      print("L1 Norm: " + norm)
+    }
+
     Y = Y + dY
     Y = Y - colMeans(Y)
 
@@ -91,20 +135,10 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 
2, Integer perplexity
     if (itr == 100) {
       P = P/4
     }
+    itr = itr + 1
   }
 }
 
-distance_matrix = function(matrix[double] X)
-  return (matrix[double] out)
-{
-  # TODO consolidate with dist() builtin, but with
-  # better way of obtaining the diag from 
-  n = nrow(X)
-  s = rowSums(X * X)
-  out = - 2*X %*% t(X) + s + t(s)
-}
-
-
 x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE)
 return(matrix[double] P)
 {
@@ -115,7 +149,7 @@ return(matrix[double] P)
   n = nrow(X)
   if(is_verbose)
     print(n)
-  D = distance_matrix(X)
+  D = dist(X)
 
   P = matrix(0, rows=n, cols=n)
   beta = matrix(1, rows=n, cols=1)
@@ -129,7 +163,7 @@ return(matrix[double] P)
   while (mean(abs(Hdiff)) > tol & itr < 50) {
     P = exp(-D * beta)
     P = P * ZERODIAG
-    sum_Pi = rowSums(P)
+    sum_Pi = rowSums(P) + 1e-12
     W = rowSums(P * D)
     Ws = W/sum_Pi
     H = log(sum_Pi) + beta * Ws
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
index 41d2d4606f..44d06ffaf9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.junit.Test;
+import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
 
@@ -41,12 +42,12 @@ public class BuiltinTSNETest extends AutomatedTestBase
        @Test
        public void testTSNECP() throws IOException {
                runTSNETest(2, 30, 300.,
-                       0.9, 1000, 42, "FALSE", ExecType.CP);
+                       0.9, 1000, 1e-5d, 42, "FALSE", 10, ExecType.CP);
        }
 
        @SuppressWarnings("unused")
-       private void runTSNETest(Integer reduced_dims, Integer perplexity, 
Double lr,
-               Double momentum, Integer max_iter, Integer seed, String 
is_verbose, ExecType instType)
+       private void runTSNETest(int reduced_dims, int perplexity, double lr,
+               double momentum, int max_iter, double tol, int seed, String 
is_verbose, Integer print_iter, ExecType instType)
                throws IOException
        {
                ExecMode platformOld = setExecMode(instType);
@@ -64,8 +65,11 @@ public class BuiltinTSNETest extends AutomatedTestBase
                                "lr=" + lr,
                                "momentum=" + momentum,
                                "max_iter=" + max_iter,
+                               "tol=" + tol,
                                "seed=" + seed,
-                               "is_verbose=" + is_verbose};
+                               "is_verbose=" + is_verbose,
+                               "print_iter=" + print_iter
+                       };
 
                        // The Input values are calculated using the following 
R script:
                        // TODO create via dml operations, avoid inlining data
@@ -403,4 +407,136 @@ public class BuiltinTSNETest extends AutomatedTestBase
                        rtplatform = platformOld;
                }
        }
+
+
+       @Test
+       public void testTSNEEarlyStopping() throws IOException {
+               // Test setup guarantees early stopping.
+               runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, 
"TRUE", 10, ExecType.CP);
+       }
+
+       @SuppressWarnings("unused")
+       private void runTSNEEarlyStoppingTest(
+               Integer reduced_dims, 
+               Integer perplexity, 
+               Double lr,
+               Double momentum, 
+               Integer max_iter, 
+               Double tol, 
+               Integer seed, 
+               String is_verbose, 
+               Integer print_iter,
+               ExecType instType) throws IOException {
+               
+               ExecMode platformOld = setExecMode(instType);
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{
+                               "-nvargs", "X=" + input("X"), "Y=" + 
output("Y"),
+                               "reduced_dims=" + reduced_dims,
+                               "perplexity=" + perplexity,
+                               "lr=" + lr,
+                               "momentum=" + momentum,
+                               "max_iter=" + max_iter,
+                               "tol=" + tol,
+                               "seed=" + seed,
+                               "is_verbose=" + is_verbose,
+                               "print_iter=" + print_iter
+                       };
+
+                       // The input values are stated to come from the following dml script (NOTE(review): the data below contains negative values, which does not match min=0 - confirm):
+                       // X = rand(rows=50, cols=2, min=0, max=5, seed=1)
+
+                       // Input
+                       double[][] X = {
+                               {-0.45700987356506406, 2.834752454661148},
+                               {-1.3945444464226533, -3.8794723634582597},
+                               {-1.338576451510809, 1.160918547857504},
+                               {2.921891699889728, -1.32749074959577},
+                               {1.3001464754324763, -1.1353208514333533},
+                               {0.2866401390950628, 3.359214248871961},
+                               {2.6740553056629217, -1.2030274345674852},
+                               {1.7240446900374895, 3.4430052477647557},
+                               {-0.3435254305219493, 4.205393963204703},
+                               {-2.873899183923896, 1.098272406118296},
+                               {4.890217056606042, 1.5814575251762104},
+                               {-4.920042511612875, 4.579455675519821},
+                               {1.439881754507784, -4.090781835042895},
+                               {2.32372435941579, 4.823050596338641},
+                               {-0.9864739586714544, -1.6990853495458147},
+                               {4.605792626050157, 2.411639339263437},
+                               {4.979120527950069, 1.7181757158820465},
+                               {-4.423608438974177, 0.44712526968937283},
+                               {3.4109472479162317, -3.497269670333382},
+                               {-1.9938801849366037, -1.1880069697833906},
+                               {3.223381639747396, 3.7784510177449793},
+                               {2.10470587687118, 0.5415570090498525},
+                               {2.084254693325721, 1.4369473809787037},
+                               {-0.9957311983302795, 1.586795215124286},
+                               {-3.7527381124013894, 4.3818220996816475},
+                               {3.5748622228245193, 1.116518048277384},
+                               {-2.297351475873446, -2.0179124546489047},
+                               {-0.3438938003649259, 0.689249021371154},
+                               {-0.8823286368673617, 1.2731356499886672},
+                               {2.517220722615252, -2.8806532181877254},
+                               {3.923092638022041, 4.34404320783608},
+                               {-2.1012040153953, -4.33147229525127},
+                               {3.5992422607685715, 2.5628828792092904},
+                               {4.3431460760781775, -2.6869010463029754},
+                               {-3.27506631006849, -1.1828954200032116},
+                               {-4.3138906717810475, -3.7311556655569875},
+                               {4.674799759142193, 3.783941497422669},
+                               {3.561677127461424, 1.699651989293141},
+                               {-3.0146338910401838, 3.3961817590254952},
+                               {-4.438156472502506, 0.5926080631113129},
+                               {-4.6425401564313615, 2.131545102584216},
+                               {3.2975878235392244, -2.8485717910480988},
+                               {-0.9776972765619627, 0.5292861827847535},
+                               {-3.9770843662935915, -2.258269867772177},
+                               {-4.22908475002643, -4.574457493889454},
+                               {-0.28759876443714827, -0.5841999820607002},
+                               {2.33121643992511, 1.7993339510854582},
+                               {-1.476311475439723, 4.3511414590258894},
+                               {4.974472387105775, -4.165990440844669},
+                               {-4.570078514420281, 2.156235882831523}
+                       };
+
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       // Capture console output
+                       setOutputBuffering(true);
+                       String out = runTest(true, false, null, -1).toString();
+
+                       // Parse and check L1 norm values
+                       String[] lines = out.split(System.lineSeparator());
+                       double prevL1Norm = Double.POSITIVE_INFINITY;
+                       boolean decreasing = true;
+                       int notDecreasingCount = 0; // Counter to track 
consecutive non-decreasing values
+                       for (String line : lines) {
+                               if (line.startsWith("L1 Norm:")) {
+                                       double l1Norm = 
Double.parseDouble(line.substring(9).trim());
+                                       if (l1Norm >= prevL1Norm) {
+                                               notDecreasingCount++;
+                                               if (notDecreasingCount >= 3) {
+                                                       decreasing = false;
+                                                       break; // Exit the loop 
once we've seen 3 consecutive non-decreasing values
+                                               }
+                                       } else {
+                                               notDecreasingCount = 0; // 
Reset the counter if the current value is decreasing
+                                       }
+                                       prevL1Norm = l1Norm;
+                               }
+                       }
+
+               assertTrue("L1 norm should decrease each time it is printed 
out", decreasing);
+               } 
+               finally {
+                       rtplatform = platformOld;
+               }
+
+       }
 }
diff --git a/src/test/scripts/functions/builtin/tSNE.dml 
b/src/test/scripts/functions/builtin/tSNE.dml
index 8310f75a39..88e7c03910 100644
--- a/src/test/scripts/functions/builtin/tSNE.dml
+++ b/src/test/scripts/functions/builtin/tSNE.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 X = read($X);
-Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed, 
$is_verbose)
+Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $tol, 
$seed, $is_verbose, $print_iter)
 write(Y, $Y)

Reply via email to