kaknikhil commented on a change in pull request #432: MADLIB-1351: Added stopping criteria on perplexity to LDA
URL: https://github.com/apache/madlib/pull/432#discussion_r324798233
 
 

 ##########
 File path: src/ports/postgres/modules/lda/lda.py_in
 ##########
 @@ -255,15 +282,65 @@ class LDATrainer:
         self.init_random()
         # sstime = time.time()
         for it in range(1, self.iter_num + 1):
+            # JIRA: MADLIB-1351
+            # If perplexity_diff has fallen below perplexity_tol,
+            # stop iterating
+            if self.perplexity_diff < self.perplexity_tol:
+                break
             self.iteration(it)
         # eetime = time.time()
         # plpy.notice('\t\titeration done, time elapsed: %.2f seconds' % (eetime - sstime))
 
+
+        # JIRA: MADLIB-1351
+        # Append the final iteration number to perplexity_iters
+        if self.evaluate_every > 0:
+            self.perplexity_iters.append(self.iter_num)
+
         self.gen_final_data_tables()
 
         # etime = time.time()
         # plpy.notice('finished, time elapsed: %.2f seconds' % (etime - stime))
 
+    # Rebuild the output data table from the given work table
+    def gen_output_data_table(self, work_table_final):
+        plpy.execute("TRUNCATE TABLE " + self.output_data_table)
+        plpy.execute("""
+            INSERT INTO {output_data_table}
+            SELECT
+                docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
+                doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
+            FROM
+                {work_table_final}
+            """.format(output_data_table=self.output_data_table,
+                       topic_num=self.topic_num,
+                       work_table_final=work_table_final))
+        # etime = time.time()
+        # plpy.notice('\t\t\ttime elapsed: %.2f seconds' % (etime - stime)) 
+
+
+    def calculatePerplexity(self, it, work_table_in):
+        # JIRA: MADLIB-1351
+        # Calculate perplexity every evaluate_every iterations, skipping
+        # the calculation at the first iteration.
+        # For each iteration:
+        # - the model table is updated (for the first iteration it is the
+        #   random model; for iteration > 1, the model that is updated is
+        #   the one learnt in the previous iteration)
+        # - __lda_count_topic_agg is called, then lda_gibbs_sample, which
+        #   learns and updates the model (the updated model is not passed
+        #   back to python; the learnt model is picked up in the next
+        #   iteration)
+        # Because of this workflow we can safely ignore the first
+        # perplexity value.
+
+        if it > self.evaluate_every and self.evaluate_every > 0 and (
+                it - 1) % self.evaluate_every == 0:
+            self.gen_output_data_table(work_table_in)
+            perplexity = 0.0
+            perplexity = get_perplexity(self.schema_madlib,
+                                        self.model_table,
+                                        self.output_data_table)
+            self.perplexity_diff = abs(self.perplexity[
 
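Taken together, the hunk above wires early stopping into the training loop: gen_output_data_table() rebuilds the output table so perplexity can be computed mid-training, calculatePerplexity() evaluates it every `evaluate_every` iterations (skipping the first, random-model evaluation), and the loop breaks once the change in perplexity drops below `perplexity_tol`. Below is a minimal, self-contained sketch of that control flow with the database calls stubbed out; the attribute names mirror the diff, while `compute_perplexity()` is a hypothetical stand-in for MADlib's `get_perplexity()`:

```python
class EarlyStoppingSketch:
    """Sketch of the early-stopping flow above; not the MADlib code."""

    def __init__(self, iter_num, perplexity_tol, evaluate_every):
        self.iter_num = iter_num
        self.perplexity_tol = perplexity_tol
        self.evaluate_every = evaluate_every
        self.perplexity = []                  # perplexity history
        self.perplexity_iters = []            # iterations where it was evaluated
        self.perplexity_diff = float('inf')   # start above any tolerance

    def compute_perplexity(self, it):
        # Hypothetical stand-in for get_perplexity(); decays toward 100.0.
        return 100.0 + 50.0 / it

    def run(self):
        for it in range(1, self.iter_num + 1):
            # MADLIB-1351: stop once successive perplexity values agree
            # to within perplexity_tol.
            if self.perplexity_diff < self.perplexity_tol:
                break
            # Same guard as the diff: evaluate every `evaluate_every`
            # iterations and skip the first (random-model) evaluation.
            if (self.evaluate_every > 0 and it > self.evaluate_every
                    and (it - 1) % self.evaluate_every == 0):
                p = self.compute_perplexity(it)
                if self.perplexity:
                    self.perplexity_diff = abs(self.perplexity[-1] - p)
                self.perplexity.append(p)
                self.perplexity_iters.append(it)
        # Post-loop bookkeeping from the diff: record the last iteration.
        if self.evaluate_every > 0:
            self.perplexity_iters.append(self.iter_num)


t = EarlyStoppingSketch(iter_num=50, perplexity_tol=0.1, evaluate_every=2)
t.run()
print(t.perplexity_iters)
```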
 Review comment:
   refactor `self.perplexity[len(self.perplexity)` as `self.perplexity[-1]`
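For reference, `len(lst) - 1` and `-1` address the same element of a non-empty list; negative indexing is the idiomatic way to read the last entry. A tiny sketch with a made-up perplexity history:

```python
perplexity = [150.0, 120.0, 110.5]               # illustrative values only

last_verbose = perplexity[len(perplexity) - 1]   # form used in the PR
last_idiomatic = perplexity[-1]                  # suggested refactor

assert last_verbose == last_idiomatic            # both read 110.5
```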

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
