imbajin commented on code in PR #105:
URL: 
https://github.com/apache/incubator-hugegraph-ai/pull/105#discussion_r1856470801


##########
hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py:
##########
@@ -19,21 +19,43 @@
 import os
 from typing import Dict, Any
 
+from tqdm import tqdm
+import pandas as pd
 from hugegraph_llm.config import resource_path
 from hugegraph_llm.models.embeddings.base import BaseEmbedding
 from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.utils.log import log
 
 
 class GremlinExampleIndexQuery:
-    def __init__(self, query: str, embedding: BaseEmbedding, num_examples: int 
= 1):
-        self.query = query
-        self.embedding = embedding
+    def __init__(self, embedding: BaseEmbedding = None, num_examples: int = 1):
+        self.embedding = embedding or Embeddings().get_embedding()
         self.num_examples = num_examples
         self.index_dir = os.path.join(resource_path, "gremlin_examples")
+        if not (os.path.exists(os.path.join(self.index_dir, "index.faiss"))
+                and os.path.exists(os.path.join(self.index_dir, 
"properties.pkl"))):
+            log.warning("No gremlin example index found, will generate one.")
+            self._build_default_example_index()
         self.vector_index = VectorIndex.from_index_file(self.index_dir)
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        context["query"] = self.query
-        query_embedding = self.embedding.get_text_embedding(self.query)
-        context["match_result"] = self.vector_index.search(query_embedding, 
self.num_examples)
+        query = context.get("query")
+        assert query, "query is required"
+        if self.num_examples == 0:
+            context["match_result"] = []
+        else:
+            if "query_embedding" in context and 
isinstance(context["query_embedding"], list):
+                query_embedding = context["query_embedding"]
+            else:
+                query_embedding = self.embedding.get_text_embedding(query)
+            context["match_result"] = 
self.vector_index.search(query_embedding, self.num_examples, dis_threshold=2)

Review Comment:
   Why do we set `dis_threshold` to 2? (Does this mean we always match the top K?)
   
   This does not seem reasonable.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to