This is an automated email from the ASF dual-hosted git repository.

jin pushed a commit to branch tmp
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git

commit ba25a74520b30afc965cb24e854222c11749d3c1
Author: yucui <[email protected]>
AuthorDate: Thu Dec 5 14:20:41 2024 +0800

    refactor: handle multi vids path bug
---
 .../operators/hugegraph_op/graph_rag_query.py      | 50 ++++++++++++----------
 1 file changed, 28 insertions(+), 22 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index a3dc1ad..a2d113b 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -90,19 +90,7 @@ class GraphRAGQuery:
         self._max_e_prop_len = max_e_prop_len
 
     def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
-        # pylint: disable=R0915 (too-many-statements)
-        if self._client is None:
-            if isinstance(context.get("graph_client"), PyHugeClient):
-                self._client = context["graph_client"]
-            else:
-                ip = context.get("ip") or "localhost"
-                port = context.get("port") or "8080"
-                graph = context.get("graph") or "hugegraph"
-                user = context.get("user") or "admin"
-                pwd = context.get("pwd") or "admin"
-                gs = context.get("graphspace") or None
-                self._client = PyHugeClient(ip, port, graph, user, pwd, gs)
-        assert self._client is not None, "No valid graph to search."
+        self._init_client(context)
 
         # 2. Extract params from context
         matched_vids = context.get("match_vids")
@@ -129,15 +117,18 @@ class GraphRAGQuery:
             log.debug("Vids gremlin query: %s", gremlin_query)
 
             vertex_knowledge = 
self._format_graph_from_vertex(query_result=vertexes)
-            gremlin_query = VID_QUERY_NEIGHBOR_TPL.format(
-                keywords=matched_vids,
-                max_deep=self._max_deep,
-                edge_labels=edge_labels_str,
-                edge_limit=edge_limit_amount,
-                max_items=self._max_items,
-            )
-            log.debug("Kneighbor gremlin query: %s", gremlin_query)
-            paths = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
+            paths: List[Any] = []
+            # TODO: 这里后续改为使用生成器 or 异步 asycnio 处理以提高性能
+            for matched_vid in matched_vids:
+                gremlin_query = VID_QUERY_NEIGHBOR_TPL.format(
+                    keywords="'{}'".format(matched_vid),
+                    max_deep=self._max_deep,
+                    edge_labels=edge_labels_str,
+                    edge_limit=edge_limit_amount,
+                    max_items=self._max_items,
+                )
+                log.debug("Kneighbor gremlin query: %s", gremlin_query)
+                
paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"])
 
             graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = 
self._format_graph_query_result(
                 query_paths=paths
@@ -182,6 +173,21 @@ class GraphRAGQuery:
         log.debug("Knowledge from Graph:\n%s", 
"\n".join(context["graph_result"]))
         return context
 
+    def _init_client(self, context):
+        # pylint: disable=R0915 (too-many-statements)
+        if self._client is None:
+            if isinstance(context.get("graph_client"), PyHugeClient):
+                self._client = context["graph_client"]
+            else:
+                ip = context.get("ip") or "localhost"
+                port = context.get("port") or "8080"
+                graph = context.get("graph") or "hugegraph"
+                user = context.get("user") or "admin"
+                pwd = context.get("pwd") or "admin"
+                gs = context.get("graphspace") or None
+                self._client = PyHugeClient(ip, port, graph, user, pwd, gs)
+        assert self._client is not None, "No valid graph to search."
+
     def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
         knowledge = set()
         for item in query_result:

Reply via email to