This is an automated email from the ASF dual-hosted git repository.
jin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-ai.git
The following commit(s) were added to refs/heads/main by this push:
new 2eb7834b feat(llm): update keyword extraction method (BREAKING CHANGE)
(#282)
2eb7834b is described below
commit 2eb7834bb6e859599a088a20c9c4328f56d2a5e6
Author: Frui Guo <[email protected]>
AuthorDate: Tue Oct 21 17:33:31 2025 +0800
feat(llm): update keyword extraction method (BREAKING CHANGE) (#282)
BREAKING CHANGE
**MUST**: Update your "keyword extract prompt" to the latest version.
Fixes the problem in #224 and updates the UI to support changing the keyword
extraction method.
**Main changes**
Added options to the RAG interface for selecting the keyword extraction
method (including LLM, TextRank, Hybrid) and the max number of keywords.
<img width="619" height="145" alt="QQ20250818-193453"
src="https://github.com/user-attachments/assets/3c0d21f0-82bb-4176-bfe2-1b0744c06b6d"
/>
A 'TextRank mask words' setting has also been added. It allows users to
manually input specific phrases composed of letters and symbols to
prevent them from being split during word segmentation. The input
will also be saved.
<img width="1207" height="263" alt="QQ20250818-193518"
src="https://github.com/user-attachments/assets/6366789a-f87d-46a4-a85a-9f3b4d9ce9a5"
/>
**Test results**
TextRank Method:
- Input:
<img width="363" height="144" alt="image"
src="https://github.com/user-attachments/assets/4a6267f7-3982-4fca-82df-60cd55bed6af"
/>
-Result:
<img width="232" height="118" alt="image"
src="https://github.com/user-attachments/assets/54a34d00-e588-44ad-9eff-d7281d7d93e5"
/>
Hybrid Method:
<img width="710" height="129" alt="QQ20250818-193508"
src="https://github.com/user-attachments/assets/541534fd-cec0-4002-9967-e49954a6c19e"
/>
---------
Co-authored-by: imbajin <[email protected]>
---
hugegraph-llm/.gitignore | 6 +-
hugegraph-llm/config.md | 19 ++-
hugegraph-llm/pyproject.toml | 3 +
hugegraph-llm/src/hugegraph_llm/api/rag_api.py | 3 +-
.../src/hugegraph_llm/config/llm_config.py | 4 +
.../config/models/base_prompt_config.py | 2 +-
.../src/hugegraph_llm/config/prompt_config.py | 22 ++-
.../src/hugegraph_llm/demo/rag_demo/app.py | 3 +-
.../src/hugegraph_llm/demo/rag_demo/rag_block.py | 8 +-
.../operators/common_op/nltk_helper.py | 62 +++++++-
.../operators/document_op/textrank_word_extract.py | 151 +++++++++++++++++++
.../operators/document_op/word_extract.py | 10 +-
.../src/hugegraph_llm/operators/graph_rag_task.py | 19 +--
.../operators/llm_op/keyword_extract.py | 164 +++++++++++++++------
pyproject.toml | 3 +
15 files changed, 390 insertions(+), 89 deletions(-)
diff --git a/hugegraph-llm/.gitignore b/hugegraph-llm/.gitignore
index cf45784b..85267f3a 100644
--- a/hugegraph-llm/.gitignore
+++ b/hugegraph-llm/.gitignore
@@ -1,7 +1,7 @@
src/hugegraph_llm/resources/*
-!/src/hugegraph_llm/resources/demo/*
-!/src/hugegraph_llm/resources/nltk_data/*
-!/src/hugegraph_llm/resources/prompt_examples/*
+!/src/hugegraph_llm/resources/demo/
+!/src/hugegraph_llm/resources/nltk_data/corpora/stopwords/
+!/src/hugegraph_llm/resources/prompt_examples/
uv.lock
diff --git a/hugegraph-llm/config.md b/hugegraph-llm/config.md
index a55172f3..448661c9 100644
--- a/hugegraph-llm/config.md
+++ b/hugegraph-llm/config.md
@@ -26,14 +26,17 @@
### 基础配置
-| 配置项 | 类型
| 默认值 | 说明 |
-|---------------------|--------------------------------------------------------|--------|---------------------------------------|
-| `LANGUAGE` | Literal["EN", "CN"]
| EN | prompt语言,支持 EN(英文)和 CN(中文) |
-| `CHAT_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 聊天 LLM 类型:openai/litellm/ollama/local |
-| `EXTRACT_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 信息提取 LLM 类型 |
-| `TEXT2GQL_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 文本转 GQL LLM 类型 |
-| `EMBEDDING_TYPE` | Optional[Literal["openai", "litellm", "ollama/local"]]
| openai | 嵌入模型类型 |
-| `RERANKER_TYPE` | Optional[Literal["cohere", "siliconflow"]]
| None | 重排序模型类型:cohere/siliconflow |
+| 配置项 | 类型
| 默认值 | 说明 |
+|------------------------|--------------------------------------------------------|--------|---------------------------------------|
+| `LANGUAGE` | Literal["EN", "CN"]
| EN | prompt语言,支持 EN(英文)和 CN(中文) |
+| `CHAT_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 聊天 LLM 类型:openai/litellm/ollama/local |
+| `EXTRACT_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 信息提取 LLM 类型 |
+| `TEXT2GQL_LLM_TYPE` | Literal["openai", "litellm", "ollama/local"]
| openai | 文本转 GQL LLM 类型 |
+| `EMBEDDING_TYPE` | Optional[Literal["openai", "litellm",
"ollama/local"]] | openai | 嵌入模型类型 |
+| `RERANKER_TYPE` | Optional[Literal["cohere", "siliconflow"]]
| None | 重排序模型类型:cohere/siliconflow |
+| `KEYWORD_EXTRACT_TYPE` | Literal["llm", "textrank", "hybrid"]
| llm | 关键词提取模型类型:llm/textrank/hybrid |
+| `WINDOW_SIZE` | Optional[Integer] | 3 | TextRank 滑窗大小 (范围:
1-10),较大的窗口可以捕获更长距离的词语关系,但会增加计算复杂度 |
+| `HYBRID_LLM_WEIGHTS` | Optional[Float] | 0.5 | 混合模式中 LLM 结果的权重 (范围:
0.0-1.0),TextRank 权重 = 1 - 该值。推荐 0.5 以平衡两种方法 |
### OpenAI 配置
diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml
index 2ed43896..1bd3b748 100644
--- a/hugegraph-llm/pyproject.toml
+++ b/hugegraph-llm/pyproject.toml
@@ -39,6 +39,9 @@ dependencies = [
"numpy",
"pandas",
"pydantic",
+ "scipy",
+ "python-igraph",
+
# LLM specific dependencies
"openai",
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index c39c7771..bfa76e7e 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -28,11 +28,12 @@ from hugegraph_llm.api.models.rag_requests import (
GraphRAGRequest,
GremlinGenerateRequest,
)
-from hugegraph_llm.config import huge_settings
from hugegraph_llm.api.models.rag_response import RAGResponse
+from hugegraph_llm.config import huge_settings
from hugegraph_llm.config import llm_settings, prompt
from hugegraph_llm.utils.log import log
+
# pylint: disable=too-many-statements
def rag_http_api(
router: APIRouter,
diff --git a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
index b2029d98..64d851f5 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py
@@ -30,6 +30,10 @@ class LLMConfig(BaseConfig):
text2gql_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai"
embedding_type: Optional[Literal["openai", "litellm", "ollama/local"]] =
"openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
+ keyword_extract_type: Literal["llm", "textrank", "hybrid"] = "llm"
+ window_size: Optional[int] = 3
+ hybrid_llm_weights: Optional[float] = 0.5
+ # TODO: divide RAG part if necessary
# 1. OpenAI settings
openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL",
"https://api.openai.com/v1")
openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
diff --git
a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
index 7af1ef92..1008c3c1 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
import os
+import sys
from pathlib import Path
import yaml
diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
index e5e1c926..eaccbefa 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py
@@ -158,9 +158,12 @@ Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a
# Extracted from llm_op/keyword_extract.py
keywords_extract_prompt_EN: str = """Instructions:
Please perform the following tasks on the text below:
- 1. Extract keywords from the text:
+ 1. Extract, evaluate, and rank keywords from the text:
- Minimum 0, maximum MAX_KEYWORDS keywords.
- - Keywords should be complete semantic words or phrases, ensuring
information completeness.
+ - Keywords should be complete semantic words or phrases, ensuring
information completeness, without any changes to the English capitalization.
+ - Assign an importance score to each keyword, as a float between 0.0
and 1.0. A higher score indicates a greater contribution to the core idea of
the text.
+ - Keywords may contain spaces, but must not contain commas or colons.
+ - The final list of keywords must be sorted in descending order based
on their importance score.
2. Identify keywords that need rewriting:
- From the extracted keywords, identify those that are ambiguous or
lack information in the original context.
3. Generate synonyms:
@@ -179,9 +182,9 @@ Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a
- Adjust keyword length: If keywords are relatively broad, you can
appropriately increase individual keyword length based on context (e.g.,
"illegal behavior" can be extracted as a single keyword, or as "illegal", but
should not be split into "illegal" and "behavior").
Output Format:
- - Output only one line, prefixed with KEYWORDS:, followed by all keywords
or corresponding synonyms, separated by commas. No spaces or empty characters
are allowed in the extracted keywords.
+ - Output only one line, prefixed with KEYWORDS:, followed by a
comma-separated list of items. Each item should be in the format
keyword:importance_score(round to two decimal places). If a keyword has been
replaced by a synonym, use the synonym as the keyword in the output.
- Format example:
- KEYWORDS:keyword1,keyword2,...,keywordN
+ KEYWORDS:keyword1:score1,keyword2:score2,...,keywordN:scoreN
MAX_KEYWORDS: {max_keywords}
Text:
@@ -366,9 +369,12 @@ g.V().limit(10)
keywords_extract_prompt_CN: str = """指令:
请对以下文本执行以下任务:
-1. 从文本中提取关键词:
+1. 从文本中提取、评估与排序关键词:
- 最少 0 个,最多 MAX_KEYWORDS 个。
- - 关键词应为具有完整语义的词语或短语,确保信息完整。
+ - 关键词应为具有完整语义的词语或短语,确保信息完整,英文大小写不做改动。
+ - 为每个关键词进行重要性评分,分值在 0.0 到 1.0 之间,浮点数表示,分数越高代表其对文本核心思想的贡献越大。
+ - 关键词内不得包含逗号或冒号(用于分隔)。
+ - 最终输出的关键词列表必须按照重要性评分 **从高到低** 进行排序。
2. 识别需改写的关键词:
- 从提取的关键词中,识别那些在原语境中具有歧义或存在信息缺失的关键词。
3. 生成同义词:
@@ -384,9 +390,9 @@ g.V().limit(10)
- 仅考虑语境相关的同义词:只需考虑给定语境下的关键词的语义近义词和具有类似含义的其他词语。
-
调整关键词长度:如果关键词相对宽泛,可以根据语境适当增加单个关键词的长度(例如:“违法行为”可以作为一个单独的关键词被抽取,或抽取为“违法”,但不应拆分为“违法”和“行为”)。
输出格式:
-- 仅输出一行内容,以 KEYWORDS: 为前缀,后跟所有关键词或对应的同义词,之间用逗号分隔。抽取的关键词中不允许出现空格或空字符
+- 仅输出一行内容,以 KEYWORDS: 为前缀,后跟列表项,关键词提取列表项为
关键词:重要性评分,评分建议保留两位小数,同义词提取列表项为对应的同义词,列表项之间用逗号分隔。
- 格式示例:
-KEYWORDS:关键词 1,关键词 2,...,关键词 n
+KEYWORDS:关键词_1:分数_1,关键词_2:分数_2,...,关键词_n:分数_n
MAX_KEYWORDS: {max_keywords}
文本:
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
index 2f9c3b34..4e575ddd 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
@@ -16,6 +16,7 @@
# under the License.
import argparse
+
import gradio as gr
import uvicorn
from fastapi import FastAPI, Depends, APIRouter
@@ -101,7 +102,7 @@ def init_rag_ui() -> gr.Interface:
textbox_inp,
textbox_answer_prompt_input,
textbox_keywords_extract_prompt_input,
- textbox_custom_related_information,
+ textbox_custom_related_information
) = create_rag_block()
with gr.Tab(label="3. Text2gremlin ⚙️"):
textbox_gremlin_inp, textbox_gremlin_schema,
textbox_gremlin_prompt = (
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
index cc6bb44e..982436b0 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py
@@ -18,18 +18,19 @@
# pylint: disable=E1101
import os
-from typing import AsyncGenerator, Tuple, Literal, Optional
+from typing import AsyncGenerator, Literal, Optional, Tuple
import gradio as gr
import pandas as pd
from gradio.utils import NamedString
-from hugegraph_llm.config import resource_path, prompt, huge_settings,
llm_settings
+from hugegraph_llm.config import huge_settings, llm_settings, prompt,
resource_path
from hugegraph_llm.operators.graph_rag_task import RAGPipeline
-from hugegraph_llm.utils.decorators import with_task_id
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
+from hugegraph_llm.utils.decorators import with_task_id
from hugegraph_llm.utils.log import log
+
def rag_answer(
text: str,
raw_answer: bool,
@@ -261,7 +262,6 @@ def create_rag_block():
show_copy_button=True,
latex_delimiters=[{"left": "$", "right": "$", "display":
False}],
)
-
answer_prompt_input = gr.Textbox(
value=prompt.answer_prompt, label="Query Prompt",
show_copy_button=True, lines=7
)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
index 797ea70a..30b2c649 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/nltk_helper.py
@@ -15,16 +15,17 @@
# specific language governing permissions and limitations
# under the License.
-
import os
import sys
from pathlib import Path
from typing import List, Optional, Dict
+from urllib.error import URLError, HTTPError
import nltk
from nltk.corpus import stopwords
from hugegraph_llm.config import resource_path
+from hugegraph_llm.utils.log import log
class NLTKHelper:
@@ -35,7 +36,9 @@ class NLTKHelper:
def stopwords(self, lang: str = "chinese") -> List[str]:
"""Get stopwords."""
- nltk.data.path.append(os.path.join(resource_path, "nltk_data"))
+ _hugegraph_source_dir = os.path.join(resource_path, "nltk_data")
+ if _hugegraph_source_dir not in nltk.data.path:
+ nltk.data.path.append(_hugegraph_source_dir)
if self._stopwords.get(lang) is None:
cache_dir = self.get_cache_dir()
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
@@ -47,11 +50,64 @@ class NLTKHelper:
try:
nltk.data.find("corpora/stopwords")
except LookupError:
- nltk.download("stopwords", download_dir=nltk_data_dir)
+ try:
+ log.info("Start download nltk package stopwords")
+ nltk.download("stopwords", download_dir=nltk_data_dir,
quiet=False)
+ log.debug("NLTK package stopwords is already downloaded")
+ except (URLError, HTTPError, PermissionError) as e:
+ log.warning("Can't download package stopwords as error:
%s", e)
+ try:
self._stopwords[lang] = stopwords.words(lang)
+ except LookupError as e:
+ log.error("NLTK stopwords for lang=%s not found: %s; using empty
list", lang, e)
+ self._stopwords[lang] = []
+
+ # final check
+ final_stopwords = self._stopwords[lang]
+ if final_stopwords is None:
+ return []
return self._stopwords[lang]
+ def check_nltk_data(self):
+ _hugegraph_source_dir = os.path.join(resource_path, "nltk_data")
+ if _hugegraph_source_dir not in nltk.data.path:
+ nltk.data.path.append(_hugegraph_source_dir)
+
+ cache_dir = self.get_cache_dir()
+ nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
+ if nltk_data_dir not in nltk.data.path:
+ nltk.data.path.append(nltk_data_dir)
+
+ required_packages = {
+ 'punkt': 'tokenizers/punkt',
+ 'punkt_tab': 'tokenizers/punkt_tab',
+ 'averaged_perceptron_tagger': 'taggers/averaged_perceptron_tagger',
+ "averaged_perceptron_tagger_eng":
'taggers/averaged_perceptron_tagger_eng'
+ }
+
+ for package, path in required_packages.items():
+ try:
+ nltk.data.find(path)
+ except LookupError:
+ log.info("Start download nltk package %s", package)
+ try:
+ if not nltk.download(package, download_dir=nltk_data_dir,
quiet=False):
+ log.warning("NLTK download command returned False for
package %s.", package)
+ return False
+ # Verify after download
+ nltk.data.find(path)
+ except PermissionError as e:
+ log.error("Permission denied when downloading %s: %s",
package, e)
+ return False
+ except (URLError, HTTPError) as e:
+ log.warning("Network error downloading %s: %s, will retry
with backup method", package, e)
+ return False
+ except LookupError:
+ log.error("Package %s not found after download. Check
package name and nltk_data paths.", package)
+ return False
+ return True
+
@staticmethod
def get_cache_dir() -> str:
"""Locate a platform-appropriate cache directory for hugegraph-llm,
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py
b/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py
new file mode 100644
index 00000000..1bd17c73
--- /dev/null
+++
b/hugegraph-llm/src/hugegraph_llm/operators/document_op/textrank_word_extract.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+from collections import defaultdict
+from typing import Dict
+
+import igraph as ig
+import jieba.posseg as pseg
+import nltk
+import regex
+
+from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
+from hugegraph_llm.utils.log import log
+
+
+class MultiLingualTextRank:
+ def __init__(self, keyword_num: int = 5, window_size: int = 3):
+ self.top_k = keyword_num
+ self.window = window_size if 0 < window_size <= 10 else 3
+ self.graph = None
+ self.max_len = 100
+
+ self.pos_filter = {
+ 'chinese': ('n', 'nr', 'ns', 'nt', 'nrt', 'nz', 'v', 'vd', 'vn',
"eng", "j", "l"),
+ 'english': ('NN', 'NNS', 'NNP', 'NNPS', 'VB', 'VBG', 'VBN', 'VBZ')
+ }
+ self.rules = [r"https?://\S+|www\.\S+",
+ r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
+ r"\b\w+(?:[-’\']\w+)+\b",
+ r"\b\d+[,.]\d+\b"]
+
+ def _word_mask(self, text):
+
+ placeholder_id_counter = 0
+ placeholder_map = {}
+
+ def _create_placeholder(match_obj):
+ nonlocal placeholder_id_counter
+ original_word = match_obj.group(0)
+ _placeholder = f" __shieldword_{placeholder_id_counter}__ "
+ placeholder_map[_placeholder.strip()] = original_word
+ placeholder_id_counter += 1
+ return _placeholder
+
+ special_regex = regex.compile('|'.join(self.rules), regex.V1)
+ text = special_regex.sub(_create_placeholder, text)
+
+ return text, placeholder_map
+
+ @staticmethod
+ def _get_valid_tokens(masked_text):
+ patterns_to_keep = [
+ r'__shieldword_\d+__',
+ r'\b\w+\b',
+ r'[\u4e00-\u9fff]+'
+ ]
+ combined_pattern = re.compile('|'.join(patterns_to_keep),
re.IGNORECASE)
+ tokens = combined_pattern.findall(masked_text)
+ text_for_nltk = ' '.join(tokens)
+ nltk_tokens = nltk.word_tokenize(text_for_nltk)
+ pos_tags = nltk.pos_tag(nltk_tokens)
+ return pos_tags
+
+ def _multi_preprocess(self, text):
+ words = []
+ en_stop_words = NLTKHelper().stopwords(lang='english')
+ ch_stop_words = NLTKHelper().stopwords(lang='chinese')
+
+ # Filtering special words, cleansing punctuation marks, and filtering
out invalid tokens
+ masked_text, placeholder_map = self._word_mask(text)
+ pos_tags = self._get_valid_tokens(masked_text)
+
+ # Word segmentation
+ for word, flag in pos_tags:
+ if word in placeholder_map:
+ words.append(placeholder_map[word])
+ continue
+
+ if len(word) >= 1 and flag in self.pos_filter['english'] and
word.lower() not in en_stop_words:
+ words.append(word)
+ if re.compile('[\u4e00-\u9fff]').search(word):
+ jieba_tokens = pseg.cut(word)
+ for ch_word, ch_flag in jieba_tokens:
+ if len(ch_word) >= 1 and ch_flag in
self.pos_filter['chinese'] \
+ and ch_word not in ch_stop_words:
+ words.append(ch_word)
+ elif len(word) >= 1 and flag in self.pos_filter['english'] and
word.lower() not in en_stop_words:
+ words.append(word)
+ return words
+
+ def _build_graph(self, words):
+ unique_words = list(set(words))
+ name_to_idx = {word: idx for idx, word in enumerate(unique_words)}
+ edge_weights = defaultdict(int)
+ for i, word1 in enumerate(words):
+ for j in range(i + 1, min(i + self.window + 1, len(words))):
+ word2 = words[j]
+ if word1 != word2:
+ pair = tuple(sorted((word1, word2)))
+ edge_weights[pair] += 1
+
+ graph = ig.Graph(n=len(unique_words), directed=False)
+ graph.vs['name'] = unique_words
+ edges_idx = [(name_to_idx[a], name_to_idx[b]) for (a, b) in
edge_weights.keys()]
+ graph.add_edges(edges_idx)
+ graph.es['weight'] = list(edge_weights.values())
+ self.graph = graph
+
+ def _rank_nodes(self):
+ if not self.graph or self.graph.vcount() == 0:
+ return {}
+
+ pagerank_scores = self.graph.pagerank(directed=False, damping=0.85,
weights='weight')
+ if max(pagerank_scores) > 0:
+ pagerank_scores = [scores/max(pagerank_scores) for scores in
pagerank_scores]
+ node_names = self.graph.vs['name']
+ return dict(zip(node_names, pagerank_scores))
+
+ def extract_keywords(self, text) -> Dict[str, float]:
+ if not NLTKHelper().check_nltk_data():
+ log.error("NLTK data check failed, cannot proceed with keyword
extraction")
+ return {}
+
+ words = self._multi_preprocess(text)
+ if not words:
+ return {}
+
+ # PageRank
+ unique_words = list(dict.fromkeys(words))
+ ranks = dict(zip(unique_words, [0] * len(unique_words)))
+ if len(unique_words) > 1:
+ self._build_graph(words)
+ if not self.graph or self.graph.vcount() == 0:
+ return {}
+ ranks = self._rank_nodes()
+ return ranks
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
index 895a3795..a873e19a 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py
@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, List
import jieba
+from hugegraph_llm.config import llm_settings
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
@@ -31,11 +32,10 @@ class WordExtract:
self,
text: Optional[str] = None,
llm: Optional[BaseLLM] = None,
- language: str = "english",
):
self._llm = llm
self._query = text
- self._language = language.lower()
+ self._language = llm_settings.language.lower()
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._query is None:
@@ -48,10 +48,8 @@ class WordExtract:
self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
- if isinstance(context.get("language"), str):
- self._language = context["language"].lower()
- else:
- context["language"] = self._language
+ # 未传入值或者其他值,默认使用英文
+ self._language = "chinese" if self._language == "cn" else "english"
keywords = jieba.lcut(self._query)
keywords = self._filter_keywords(keywords, lowercase=False)
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 65c95db5..be0ac0ca 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -16,8 +16,9 @@
# under the License.
-from typing import Dict, Any, Optional, List, Literal
+from typing import Any, Dict, List, Literal, Optional
+from hugegraph_llm.config import huge_settings, prompt
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
from hugegraph_llm.models.llms.base import BaseLLM
@@ -31,8 +32,7 @@ from hugegraph_llm.operators.index_op.semantic_id_query
import SemanticIdQuery
from hugegraph_llm.operators.index_op.vector_index_query import
VectorIndexQuery
from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize
from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract
-from hugegraph_llm.utils.decorators import log_time, log_operator_time,
record_rpm
-from hugegraph_llm.config import prompt, huge_settings
+from hugegraph_llm.utils.decorators import log_operator_time, log_time,
record_rpm
class RAGPipeline:
@@ -54,39 +54,32 @@ class RAGPipeline:
self._embedding = embedding or Embeddings().get_embedding()
self._operators: List[Any] = []
- def extract_word(self, text: Optional[str] = None, language: str =
"english"):
+ def extract_word(self, text: Optional[str] = None):
"""
Add a word extraction operator to the pipeline.
:param text: Text to extract words from.
- :param language: Language of the text.
:return: Self-instance for chaining.
"""
- self._operators.append(WordExtract(text=text, language=language))
+ self._operators.append(WordExtract(text=text))
return self
def extract_keywords(
self,
text: Optional[str] = None,
- max_keywords: int = 5,
- language: str = "english",
extract_template: Optional[str] = None,
):
"""
Add a keyword extraction operator to the pipeline.
:param text: Text to extract keywords from.
- :param max_keywords: Maximum number of keywords to extract.
- :param language: Language of the text.
:param extract_template: Template for keyword extraction.
:return: Self-instance for chaining.
"""
self._operators.append(
KeywordExtract(
text=text,
- max_keywords=max_keywords,
- language=language,
- extract_template=extract_template,
+ extract_template=extract_template
)
)
return self
diff --git
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
index b1e3c7db..1e9ca652 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py
@@ -15,15 +15,14 @@
# specific language governing permissions and limitations
# under the License.
-
import re
import time
-from typing import Set, Dict, Any, Optional
+from typing import Any, Dict, Optional
-from hugegraph_llm.config import prompt
+from hugegraph_llm.config import prompt, llm_settings
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
-from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper
+from hugegraph_llm.operators.document_op.textrank_word_extract import
MultiLingualTextRank
from hugegraph_llm.utils.log import log
KEYWORDS_EXTRACT_TPL = prompt.keywords_extract_prompt
@@ -31,18 +30,21 @@ KEYWORDS_EXTRACT_TPL = prompt.keywords_extract_prompt
class KeywordExtract:
def __init__(
- self,
- text: Optional[str] = None,
- llm: Optional[BaseLLM] = None,
- max_keywords: int = 5,
- extract_template: Optional[str] = None,
- language: str = "english",
+ self,
+ text: Optional[str] = None,
+ llm: Optional[BaseLLM] = None,
+ max_keywords: int = 5,
+ extract_template: Optional[str] = None,
):
self._llm = llm
self._query = text
- self._language = language.lower()
+ self._language = llm_settings.language.lower()
self._max_keywords = max_keywords
self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL
+ self._extract_method = llm_settings.keyword_extract_type.lower()
+ self._textrank_model = MultiLingualTextRank(
+ keyword_num=max_keywords,
+ window_size=llm_settings.window_size)
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
if self._query is None:
@@ -55,48 +57,128 @@ class KeywordExtract:
self._llm = LLMs().get_extract_llm()
assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
- self._language = context.get("language", self._language).lower()
- self._max_keywords = context.get("max_keywords", self._max_keywords)
+ # Use English by default
+ self._language = "chinese" if self._language == "cn" else "english"
+ max_keyword_num = context.get("max_keywords", self._max_keywords)
+ try:
+ max_keyword_num = int(max_keyword_num)
+ except (TypeError, ValueError):
+ max_keyword_num = self._max_keywords
+ self._max_keywords = max(1, max_keyword_num)
+
+ method = (context.get("extract_method", self._extract_method) or
"LLM").strip().lower()
+ if method == "llm":
+ # LLM method
+ ranks = self._extract_with_llm()
+ elif method == "textrank":
+ # TextRank method
+ ranks = self._extract_with_textrank()
+ elif method == "hybrid":
+ # Hybrid method
+ ranks = self._extract_with_hybrid()
+ else:
+ log.warning("Invalid extract_method %s", method)
+ raise ValueError(f"Invalid extract_method: {method}")
+
+ keywords = [] if not ranks else sorted(ranks, key=ranks.get,
reverse=True)
+ keywords = [k.replace("'", "") for k in keywords]
+ context["keywords"] = keywords[:self._max_keywords]
+ log.info("User Query: %s\nKeywords: %s", self._query,
context["keywords"])
+
+ # extracting keywords & expanding synonyms increase the call count by 1
+ context["call_count"] = context.get("call_count", 0) + 1
+ return context
+ def _extract_with_llm(self) -> Dict[str, float]:
prompt_run = f"{self._extract_template.format(question=self._query,
max_keywords=self._max_keywords)}"
start_time = time.perf_counter()
response = self._llm.generate(prompt=prompt_run)
end_time = time.perf_counter()
- log.debug("Keyword extraction time: %.2f seconds", end_time -
start_time)
-
+ log.debug("LLM Keyword extraction time: %.2f seconds", end_time -
start_time)
keywords = self._extract_keywords_from_response(
response=response, lowercase=False, start_token="KEYWORDS:"
)
- keywords = {k.replace("'", "") for k in keywords}
- context["keywords"] = list(keywords)
- log.info("User Query: %s\nKeywords: %s", self._query,
context["keywords"])
+ return keywords
- # extracting keywords & expanding synonyms increase the call count by 1
- context["call_count"] = context.get("call_count", 0) + 1
- return context
+ def _extract_with_textrank(self) -> Dict[str, float]:
+ """ TextRank mode extraction """
+ start_time = time.perf_counter()
+ ranks = {}
+ try:
+ ranks = self._textrank_model.extract_keywords(self._query)
+ except (TypeError, ValueError) as e:
+ log.error("TextRank parameter error: %s", e)
+ except MemoryError as e:
+ log.critical("TextRank memory error (text too large?): %s", e)
+ end_time = time.perf_counter()
+ log.debug("TextRank Keyword extraction time: %.2f seconds",
+ end_time - start_time)
+ return ranks
+
+ def _extract_with_hybrid(self) -> Dict[str, float]:
+ """ Hybrid mode extraction """
+ ranks = {}
+
+ if isinstance(llm_settings.hybrid_llm_weights, float):
+ llm_weights = min(1.0, max(0.0,
float(llm_settings.hybrid_llm_weights)))
+ else:
+ llm_weights = 0.5
+
+ start_time = time.perf_counter()
+
+ llm_scores = self._extract_with_llm()
+ tr_scores = self._extract_with_textrank()
+ lr_set = set(k for k in llm_scores)
+ tr_set = set(k for k in tr_scores)
+
+ log.debug("LLM extract results: %s", llm_scores)
+ log.debug("TextRank extract results: %s", tr_scores)
+
+ union_set = lr_set | tr_set
+ for word in union_set:
+ ranks[word] = 0
+ if word in llm_scores:
+ ranks[word] += llm_scores[word] * llm_weights
+ if word in tr_scores:
+ ranks[word] += tr_scores[word] * (1-llm_weights)
+
+ end_time = time.perf_counter()
+ log.debug("Hybrid Keyword extraction time: %.2f seconds", end_time -
start_time)
+ return ranks
def _extract_keywords_from_response(
- self,
- response: str,
- lowercase: bool = True,
- start_token: str = "",
- ) -> Set[str]:
- keywords = []
+ self,
+ response: str,
+ lowercase: bool = True,
+ start_token: str = "",
+ ) -> Dict[str, float]:
+
+ results = {}
+
# use re.escape(start_token) if start_token contains special chars
like */&/^ etc.
- matches = re.findall(rf'{start_token}[^\n]+\n?', response)
+ matches = re.findall(rf'{start_token}([^\n]+\n?)', response)
for match in matches:
- match = match[len(start_token):].strip()
- keywords.extend(
- k.lower() if lowercase else k
- for k in re.split(r"[,,]+", match)
- if len(k.strip()) > 1
- )
-
- # if the keyword consists of multiple words, split into sub-words
(removing stopwords)
- results = set(keywords)
- for token in keywords:
- sub_tokens = re.findall(r"\w+", token)
- if len(sub_tokens) > 1:
- results.update(w for w in sub_tokens if w not in
NLTKHelper().stopwords(lang=self._language))
+ match = match.strip()
+ for k in re.split(r"[,,]+", match):
+ item = k.strip()
+ if not item:
+ continue
+ try:
+ parts = re.split(r"[::]", item, maxsplit=1)
+ if len(parts) != 2:
+ log.warning("Skipping malformed item: %s", item)
+ continue
+ word_raw, score_raw = parts[0].strip(), parts[1].strip()
+ if not word_raw:
+ continue
+ score_val = float(score_raw)
+ if not 0.0 <= score_val <= 1.0:
+ log.warning("Score out of range for %s: %s", word_raw,
score_val)
+ score_val = min(1.0, max(0.0, score_val))
+ word_out = word_raw.lower() if lowercase else word_raw
+ results[word_out] = score_val
+ except (ValueError, AttributeError) as e:
+ log.warning("Failed to parse item '%s': %s", item, e)
+ continue
return results
diff --git a/pyproject.toml b/pyproject.toml
index 9e1624b4..8bcf5892 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -104,6 +104,8 @@ constraint-dependencies = [
"numpy~=1.24.4",
"pandas~=2.2.3", # TODO: replace by polars(rust) in the future
"pydantic~=2.10.6",
+ "scipy~=1.15.3", # word segment need
+ "python-igraph~=0.11.9", # textrank need
# LLM dependencies
"openai~=1.61.0",
@@ -123,6 +125,7 @@ constraint-dependencies = [
"apscheduler~=3.10.4",
"litellm~=1.61.13",
+
# ML dependencies
"dgl~=2.1.0",
"ogb~=1.3.6",