mohamedawnallah commented on code in PR #35216:
URL: https://github.com/apache/beam/pull/35216#discussion_r2147080326


##########
sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py:
##########
@@ -0,0 +1,425 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+from enum import Enum
+
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.transforms.enrichment import EnrichmentSourceHandler
+from dataclasses import field
+from pymilvus import MilvusClient, AnnSearchRequest, SearchResult, Hits, Hit
+from google.protobuf.json_format import MessageToDict
+from collections.abc import Sequence
+
+
+class SearchStrategy(Enum):
+  HYBRID = "hybrid"  # Combined vector and keyword search
+  VECTOR = "vector"  # Vector similarity search only
+  KEYWORD = "keyword"  # Keyword/text search only
+
+
+class KeywordSearchMetrics(Enum):
+  """Metrics for keyword search."""
+  BM25 = "BM25"  # BM25 ranking algorithm for text relevance
+
+
+class VectorSearchMetrics(Enum):
+  """Metrics for vector search."""
+  COSINE = "COSINE"  # Cosine similarity (1 = identical, 0 = orthogonal)
+  L2 = "L2"  # Euclidean distance (smaller = more similar)
+  IP = "IP"  # Inner product (larger = more similar)
+
+
+class MilvusBaseRanker:
+  def __int__(self):
+    return
+
+  def dict(self):
+    return {}
+
+  def __str__(self):
+    return self.dict().__str__()
+
+
+@dataclass
+class MilvusConnectionParameters:
+  # URI endpoint for connecting to Milvus server.
+  # Format: "http(s)://hostname:port".
+  uri: str
+
+  # Username for authentication.
+  # Required if not using token authentication.
+  user: str = field(default_factory=str)
+
+  # Password for authentication.
+  # Required if not using token authentication.
+  password: str = field(default_factory=str)
+
+  # Database ID to connect to.
+  # Specifies which Milvus database to use.
+  db_id: str = "default"
+
+  # Authentication token.
+  # Alternative to username/password authentication.
+  token: str = field(default_factory=str)
+
+  # Connection timeout in seconds.
+  # If None, the client's default timeout is used.
+  timeout: Optional[float] = None
+
+  def __post_init__(self):
+    if not self.uri:
+      raise ValueError("URI must be provided for Milvus connection")
+
+
+@dataclass
+class BaseSearchParameters:
+  """Parameters for base (vector or keyword) search."""
+  # Boolean expression string for filtering search results.
+  # Example: 'price <= 1000 AND category == "electronics"'.
+  filter: str = field(default_factory=str)
+
+  # Maximum number of results to return per query.
+  # Must be a positive integer.
+  limit: int = 3
+
+  # Additional search parameters specific to the search type.
+  search_params: Dict[str, Any] = field(default_factory=dict)
+
+  # Field name containing the vector or text to search.
+  # Required for both vector and keyword search.
+  anns_field: Optional[str] = None
+
+  # Consistency level for read operations
+  # Options: "Strong", "Session", "Bounded", "Eventually".
+  consistency_level: Optional[str] = None
+
+  def __post_init__(self):
+    if self.limit <= 0:
+      raise ValueError(f"Search limit must be positive, got {self.limit}")
+
+
+@dataclass
+class VectorSearchParameters(BaseSearchParameters):
+  """Parameters for vector search."""
+  # Inherits all fields from BaseSearchParameters.
+  # Can add vector-specific parameters here.
+
+
+@dataclass
+class KeywordSearchParameters(BaseSearchParameters):
+  """Parameters for keyword search."""
+  # Inherits all fields from BaseSearchParameters.
+  # Can add keyword-specific parameters here.
+
+
+@dataclass
+class HybridSearchParameters:
+  """Parameters for hybrid (vector + keyword) search."""
+  # Ranker for combining vector and keyword search results.
+  # Example: RRFRanker(weight_vector=0.6, weight_keyword=0.4).
+  ranker: MilvusBaseRanker
+
+  # Maximum number of results to return per query
+  # Must be a positive integer.
+  limit: int = 3
+
+  def __post_init__(self):
+    if not self.ranker:
+      raise ValueError("Ranker must be provided for hybrid search")
+
+    if self.limit <= 0:
+      raise ValueError(f"Search limit must be positive, got {self.limit}")
+
+
+@dataclass
+class MilvusSearchParameters:
+  """Parameters configuring Milvus vector/keyword/hybrid search operations."""
+  # Name of the collection to search in.
+  # Must be an existing collection in the Milvus database.
+  collection_name: str
+
+  # Type of search to perform (VECTOR, KEYWORD, or HYBRID).
+  # Specifies the search approach that determines which parameters and Milvus
+  # APIs will be utilized.
+  search_strategy: SearchStrategy
+
+  # Parameters for vector search.
+  # Required when search_strategy is VECTOR or HYBRID.
+  vector: Optional[VectorSearchParameters] = None
+
+  # Parameters for keyword search.
+  # Required when search_strategy is KEYWORD or HYBRID.
+  keyword: Optional[KeywordSearchParameters] = None
+
+  # Parameters for hybrid search.
+  # Required when search_strategy is HYBRID.
+  hybrid: Optional[HybridSearchParameters] = None

Review Comment:
   > Rather than structuring this as 4 parameters, could we do a single parameter:
   >
   > [The proposed single-parameter snippet was not reproduced in this quote.]
   
   That would greatly simplify both the API interface and the validation code! Thanks 🙏



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@beam.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to