This is an automated email from the ASF dual-hosted git repository. pankaj pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 65020ee66a Add utils methods in pinecone provider (#35502) 65020ee66a is described below commit 65020ee66afa803f9bda226f176233e47b59a8d0 Author: Utkarsh Sharma <utkarshar...@gmail.com> AuthorDate: Tue Nov 7 20:09:46 2023 +0530 Add utils methods in pinecone provider (#35502) * Add missing methods to pinecone provider * Fix static * Add words to spelling-wordlist.txt * Update airflow/providers/pinecone/hooks/pinecone.py Co-authored-by: Pankaj Singh <98807258+pankajas...@users.noreply.github.com> * Update airflow/providers/pinecone/hooks/pinecone.py Co-authored-by: Pankaj Singh <98807258+pankajas...@users.noreply.github.com> --------- Co-authored-by: Pankaj Singh <98807258+pankajas...@users.noreply.github.com> --- airflow/providers/pinecone/hooks/pinecone.py | 219 +++++++++++++++++++++++- docs/spelling_wordlist.txt | 2 + tests/providers/pinecone/hooks/test_pinecone.py | 89 ++++++++++ 3 files changed, 308 insertions(+), 2 deletions(-) diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index 92fd620f76..9f6054ffe1 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -18,6 +18,7 @@ """Hook for Pinecone.""" from __future__ import annotations +import itertools from typing import TYPE_CHECKING, Any import pinecone @@ -25,7 +26,8 @@ import pinecone from airflow.hooks.base import BaseHook if TYPE_CHECKING: - from pinecone.core.client.models import UpsertResponse + from pinecone.core.client.model.sparse_values import SparseValues + from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse class PineconeHook(BaseHook): @@ -86,11 +88,16 @@ class PineconeHook(BaseHook): def test_connection(self) -> tuple[bool, str]: try: - pinecone.list_indexes() + self.list_indexes() return True, "Connection established" except Exception as e: return False, str(e) + @staticmethod + def list_indexes() -> Any: + """Retrieve a list of all indexes in your project.""" + return pinecone.list_indexes() + @staticmethod def upsert( index_name: str, @@ -126,3 +133,211 @@ class PineconeHook(BaseHook): show_progress=show_progress, **kwargs, ) + + @staticmethod + def create_index( + index_name: str, + dimension: int, + index_type: str | None = "approximated", + metric: str | None = "cosine", + replicas: int | None = 1, + shards: int | None = 1, + pods: int | None = 1, + pod_type: str | None = "p1", + index_config: dict[str, str] | None = None, + metadata_config: dict[str, str] | None = None, + source_collection: str | None = "", + timeout: int | None = None, + ) -> None: + """ + Create a new index. + + .. seealso:: https://docs.pinecone.io/reference/create_index/ + + :param index_name: The name of the index to create. + :param dimension: the dimension of vectors that would be inserted in the index + :param index_type: type of index, one of {"approximated", "exact"}, defaults to "approximated". + :param metric: type of metric used in the vector index, one of {"cosine", "dotproduct", "euclidean"} + :param replicas: the number of replicas, defaults to 1. + :param shards: the number of shards per index, defaults to 1. + :param pods: Total number of pods to be used by the index. pods = shard*replicas + :param pod_type: the pod type to be used for the index. can be one of p1 or s1. + :param index_config: Advanced configuration options for the index + :param metadata_config: Configuration related to the metadata index + :param source_collection: Collection name to create the index from + :param timeout: Timeout for wait until index gets ready. + """ + pinecone.create_index( + name=index_name, + timeout=timeout, + index_type=index_type, + dimension=dimension, + metric=metric, + pods=pods, + replicas=replicas, + shards=shards, + pod_type=pod_type, + metadata_config=metadata_config, + source_collection=source_collection, + index_config=index_config, + ) + + @staticmethod + def describe_index(index_name: str) -> Any: + """ + Retrieve information about a specific index. + + :param index_name: The name of the index to describe. + """ + return pinecone.describe_index(name=index_name) + + @staticmethod + def delete_index(index_name: str, timeout: int | None = None) -> None: + """ + Delete a specific index. + + :param index_name: the name of the index. + :param timeout: Timeout for wait until index gets ready. + """ + pinecone.delete_index(name=index_name, timeout=timeout) + + @staticmethod + def configure_index(index_name: str, replicas: int | None = None, pod_type: str | None = "") -> None: + """ + Changes current configuration of the index. + + :param index_name: The name of the index to configure. + :param replicas: The new number of replicas. + :param pod_type: the new pod_type for the index. + """ + pinecone.configure_index(name=index_name, replicas=replicas, pod_type=pod_type) + + @staticmethod + def create_collection(collection_name: str, index_name: str) -> None: + """ + Create a new collection from a specified index. + + :param collection_name: The name of the collection to create. + :param index_name: The name of the source index. + """ + pinecone.create_collection(name=collection_name, source=index_name) + + @staticmethod + def delete_collection(collection_name: str) -> None: + """ + Delete a specific collection. + + :param collection_name: The name of the collection to delete. + """ + pinecone.delete_collection(collection_name) + + @staticmethod + def describe_collection(collection_name: str) -> Any: + """ + Retrieve information about a specific collection. + + :param collection_name: The name of the collection to describe. + """ + return pinecone.describe_collection(collection_name) + + @staticmethod + def list_collections() -> Any: + """Retrieve a list of all collections in the current project.""" + return pinecone.list_collections() + + @staticmethod + def query_vector( + index_name: str, + vector: list[Any], + query_id: str | None = None, + top_k: int = 10, + namespace: str | None = None, + query_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None, + include_values: bool | None = None, + include_metadata: bool | None = None, + sparse_vector: SparseValues | dict[str, list[float] | list[int]] | None = None, + ) -> QueryResponse: + """ + The Query operation searches a namespace, using a query vector. + + It retrieves the ids of the most similar items in a namespace, along with their similarity scores. + API reference: https://docs.pinecone.io/reference/query + + :param index_name: The name of the index to query. + :param vector: The query vector. + :param query_id: The unique ID of the vector to be used as a query vector. + :param top_k: The number of results to return. + :param namespace: The namespace to fetch vectors from. If not specified, the default namespace is used. + :param query_filter: The filter to apply. See https://www.pinecone.io/docs/metadata-filtering/ + :param include_values: Whether to include the vector values in the result. + :param include_metadata: Indicates whether metadata is included in the response as well as the ids. + :param sparse_vector: sparse values of the query vector. Expected to be either a SparseValues object or a dict + of the form: {'indices': List[int], 'values': List[float]}, where the lists each have the same length. + """ + index = pinecone.Index(index_name) + return index.query( + vector=vector, + id=query_id, + top_k=top_k, + namespace=namespace, + filter=query_filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + ) + + @staticmethod + def _chunks(iterable: list[Any], batch_size: int = 100) -> Any: + """Helper function to break an iterable into chunks of size batch_size.""" + it = iter(iterable) + chunk = tuple(itertools.islice(it, batch_size)) + while chunk: + yield chunk + chunk = tuple(itertools.islice(it, batch_size)) + + def upsert_data_async( + self, + index_name: str, + data: list[tuple[Any]], + async_req: bool = False, + pool_threads: int | None = None, + ) -> None | list[Any]: + """ + Upserts (insert/update) data into the Pinecone index. + + :param index_name: Name of the index. + :param data: List of tuples to be upserted. Each tuple is of form (id, vector, metadata). + Metadata is optional. + :param async_req: If True, upsert operations will be asynchronous. + :param pool_threads: Number of threads for parallel upserting. If async_req is True, this must be provided. + """ + responses = [] + with pinecone.Index(index_name, pool_threads=pool_threads) as index: + if async_req and pool_threads: + async_results = [index.upsert(vectors=chunk, async_req=True) for chunk in self._chunks(data)] + responses = [async_result.get() for async_result in async_results] + else: + for chunk in self._chunks(data): + response = index.upsert(vectors=chunk) + responses.append(response) + return responses + + @staticmethod + def describe_index_stats( + index_name: str, + stats_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None, + **kwargs: Any, + ) -> DescribeIndexStatsResponse: + """ + Describes the index statistics. + + Returns statistics about the index's contents. For example: The vector count per + namespace and the number of dimensions. + API reference: https://docs.pinecone.io/reference/describe_index_stats_post + + :param index_name: Name of the index. + :param stats_filter: If this parameter is present, the operation only returns statistics for vectors that + satisfy the filter. See https://www.pinecone.io/docs/metadata-filtering/ + """ + index = pinecone.Index(index_name) + return index.describe_index_stats(filter=stats_filter, **kwargs) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index bcbda170ac..85af5b4d32 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -496,6 +496,7 @@ dogstatsd donot Dont DOS'ing +dotproduct DownloadReportV downscaling downstreams @@ -1668,6 +1669,7 @@ updateonly Upsert upsert upserted +upserting upserts Upsight upstreams diff --git a/tests/providers/pinecone/hooks/test_pinecone.py b/tests/providers/pinecone/hooks/test_pinecone.py index d358ca9485..fb076cc0a3 100644 --- a/tests/providers/pinecone/hooks/test_pinecone.py +++ b/tests/providers/pinecone/hooks/test_pinecone.py @@ -42,3 +42,92 @@ class TestPineconeHook: mock_index.return_value.upsert = mock_upsert self.pinecone_hook.upsert(self.index_name, data) mock_upsert.assert_called_once_with(vectors=data, namespace="", batch_size=None, show_progress=True) + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_indexes") + def test_list_indexes(self, mock_list_indexes): + """Test that the list_indexes method of PineconeHook is called correctly.""" + self.pinecone_hook.list_indexes() + mock_list_indexes.assert_called_once() + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index") + def test_create_index(self, mock_create_index): + """Test that the create_index method of PineconeHook is called with correct arguments.""" + self.pinecone_hook.create_index(index_name=self.index_name, dimension=128) + mock_create_index.assert_called_once_with(index_name="test_index", dimension=128) + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index") + def test_describe_index(self, mock_describe_index): + """Test that the describe_index method of PineconeHook is called with correct arguments.""" + self.pinecone_hook.describe_index(index_name=self.index_name) + mock_describe_index.assert_called_once_with(index_name=self.index_name) + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.delete_index") + def test_delete_index(self, mock_delete_index): + """Test that the delete_index method of PineconeHook is called with the correct index name.""" + self.pinecone_hook.delete_index(index_name="test_index") + mock_delete_index.assert_called_once_with(index_name="test_index") + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_collection") + def test_create_collection(self, mock_create_collection): + """ + Test that the create_collection method of PineconeHook is called correctly. + """ + self.pinecone_hook.create_collection(collection_name="test_collection") + mock_create_collection.assert_called_once_with(collection_name="test_collection") + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.configure_index") + def test_configure_index(self, mock_configure_index): + """ + Test that the configure_index method of PineconeHook is called correctly. + """ + self.pinecone_hook.configure_index(index_configuration={}) + mock_configure_index.assert_called_once_with(index_configuration={}) + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_collection") + def test_describe_collection(self, mock_describe_collection): + """ + Test that the describe_collection method of PineconeHook is called correctly. + """ + self.pinecone_hook.describe_collection(collection_name="test_collection") + mock_describe_collection.assert_called_once_with(collection_name="test_collection") + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_collections") + def test_list_collections(self, mock_list_collections): + """ + Test that the list_collections method of PineconeHook is called correctly. + """ + self.pinecone_hook.list_collections() + mock_list_collections.assert_called_once() + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.query_vector") + def test_query_vector(self, mock_query_vector): + """ + Test that the query_vector method of PineconeHook is called correctly. + """ + self.pinecone_hook.query_vector(vector=[1.0, 2.0, 3.0]) + mock_query_vector.assert_called_once_with(vector=[1.0, 2.0, 3.0]) + + def test__chunks(self): + """ + Test that the _chunks method of PineconeHook behaves as expected. + """ + data = list(range(10)) + chunked_data = list(self.pinecone_hook._chunks(data, 3)) + assert chunked_data == [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)] + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.upsert_data_async") + def test_upsert_data_async_correctly(self, mock_upsert_data_async): + """ + Test that the upsert_data_async method of PineconeHook is called correctly. + """ + data = [("id1", [1.0, 2.0, 3.0], {"meta": "data"})] + self.pinecone_hook.upsert_data_async(index_name="test_index", data=data) + mock_upsert_data_async.assert_called_once_with(index_name="test_index", data=data) + + @patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index_stats") + def test_describe_index_stats(self, mock_describe_index_stats): + """ + Test that the describe_index_stats method of PineconeHook is called correctly. + """ + self.pinecone_hook.describe_index_stats(index_name="test_index") + mock_describe_index_stats.assert_called_once_with(index_name="test_index")