imbajin commented on code in PR #105:
URL:
https://github.com/apache/incubator-hugegraph-ai/pull/105#discussion_r1856470801
##########
hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py:
##########
@@ -19,21 +19,43 @@
import os
from typing import Dict, Any
+from tqdm import tqdm
+import pandas as pd
from hugegraph_llm.config import resource_path
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.indices.vector_index import VectorIndex
+from hugegraph_llm.models.embeddings.init_embedding import Embeddings
+from hugegraph_llm.utils.log import log
class GremlinExampleIndexQuery:
- def __init__(self, query: str, embedding: BaseEmbedding, num_examples: int
= 1):
- self.query = query
- self.embedding = embedding
+ def __init__(self, embedding: BaseEmbedding = None, num_examples: int = 1):
+ self.embedding = embedding or Embeddings().get_embedding()
self.num_examples = num_examples
self.index_dir = os.path.join(resource_path, "gremlin_examples")
+ if not (os.path.exists(os.path.join(self.index_dir, "index.faiss"))
+ and os.path.exists(os.path.join(self.index_dir,
"properties.pkl"))):
+ log.warning("No gremlin example index found, will generate one.")
+ self._build_default_example_index()
self.vector_index = VectorIndex.from_index_file(self.index_dir)
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
- context["query"] = self.query
- query_embedding = self.embedding.get_text_embedding(self.query)
- context["match_result"] = self.vector_index.search(query_embedding,
self.num_examples)
+ query = context.get("query")
+ assert query, "query is required"
+ if self.num_examples == 0:
+ context["match_result"] = []
+ else:
+ if "query_embedding" in context and
isinstance(context["query_embedding"], list):
+ query_embedding = context["query_embedding"]
+ else:
+ query_embedding = self.embedding.get_text_embedding(query)
+ context["match_result"] =
self.vector_index.search(query_embedding, self.num_examples, dis_threshold=2)
Review Comment:
Why do we set `dis_threshold` to 2? Doesn't that mean it will (almost) always return the top-K matches regardless of distance?
That doesn't seem reasonable.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]