yanand0909 commented on code in PR #425: URL: https://github.com/apache/flink-agents/pull/425#discussion_r2678360443
########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import 
org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return 
(BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for 
(int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { + CompletableFuture.runAsync( + () -> { + try { + compact(memorySet); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + this.workerExecutor()) + .exceptionally( + e -> { + throw new RuntimeException( + String.format( + "Compaction for %s failed", + this.nameMangling(memorySet.getName())), + e); + }); + } else { + this.compact(memorySet); Review Comment: When `asyncCompaction = false`, the entire `add()` call blocks during compaction. Since this blocks the Flink operator, could it cause backpressure? Should we document this clearly, or make async the default? ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static 
final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return (BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + 
+ @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { + CompletableFuture.runAsync( + () -> { + try { + compact(memorySet); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + this.workerExecutor()) + .exceptionally( + e -> { + throw new RuntimeException( + String.format( + "Compaction for %s failed", + 
this.nameMangling(memorySet.getName())), + e); + }); + } else { + this.compact(memorySet); + } + } + + return itemIds; + } + + @Override + public List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) + throws Exception { + List<Document> documents = + this.store() + .get(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + return this.convertToItems(memorySet, documents); + } + + @Override + public void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception { + this.store().delete(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + } + + @Override + public List<MemorySetItem> search( + MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs) + throws Exception { + VectorStoreQuery vectorStoreQuery = + new VectorStoreQuery( + query, limit, this.nameMangling(memorySet.getName()), extraArgs); + VectorStoreQueryResult result = this.store().query(vectorStoreQuery); + return this.convertToItems(memorySet, result.getDocuments()); + } + + private String nameMangling(String name) { + return String.join("-", this.jobId, this.key, name); + } + + private List<MemorySetItem> convertToItems(MemorySet memorySet, List<Document> documents) + throws JsonProcessingException { + List<MemorySetItem> items = new ArrayList<>(); + for (Document doc : documents) { + Map<String, Object> metadata = doc.getMetadata(); + boolean compacted = (boolean) metadata.remove("compacted"); + Object createdTime; + if (compacted) { + createdTime = + new MemorySetItem.DateTimeRange( + LocalDateTime.parse( + (String) metadata.remove("created_time_start"), formatter), + LocalDateTime.parse( + (String) metadata.remove("created_time_end"), formatter)); + } else { + createdTime = + LocalDateTime.parse((String) metadata.remove("created_time"), formatter); + } + MemorySetItem item = + new MemorySetItem( + memorySet.getName(), + doc.getId(), + memorySet.getItemType() == String.class + ? 
doc.getContent() + : mapper.readValue(doc.getContent(), memorySet.getItemType()), + compacted, + createdTime, + LocalDateTime.parse( + (String) metadata.remove("last_accessed_time"), formatter), + metadata); + items.add(item); + } + return items; + } + + private void compact(MemorySet memorySet) throws Exception { + CompactionStrategy strategy = memorySet.getStrategy(); + if (strategy.type() == CompactionStrategy.Type.SUMMARIZATION) { + summarize(this, memorySet, ctx, null); + } else { + throw new RuntimeException( + String.format("Unknown compaction strategy: %s", strategy.type())); + } + } + + private ExecutorService workerExecutor() { + // TODO: shutdown executor when close. + if (lazyCompactExecutor == null) { + lazyCompactExecutor = + Executors.newFixedThreadPool( + 2, Review Comment: Should we make the number of threads configurable (e.g. via `LongTermMemoryOptions`, like `ASYNC_COMPACTION`) instead of hard-coding 2? ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final 
ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return (BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection 
= + ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { Review Comment: QQ: Could this cause race condition? 
If multiple threads add items simultaneously, compaction could be triggered multiple times before any completes. For example, while an earlier async compaction is still in flight, `memorySet.size()` can still be >= capacity, so every subsequent `add()` would schedule another compaction of the same memory set. ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final 
ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return (BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection 
= + ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { + CompletableFuture.runAsync( + () -> { + try { + compact(memorySet); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + this.workerExecutor()) + .exceptionally( + e -> { + throw new RuntimeException( + String.format( + "Compaction for %s failed", + this.nameMangling(memorySet.getName())), + e); + }); + } else { + this.compact(memorySet); + } + } + + return itemIds; + } + + 
@Override + public List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) + throws Exception { + List<Document> documents = + this.store() + .get(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + return this.convertToItems(memorySet, documents); + } + + @Override + public void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception { + this.store().delete(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + } + + @Override + public List<MemorySetItem> search( + MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs) + throws Exception { + VectorStoreQuery vectorStoreQuery = + new VectorStoreQuery( + query, limit, this.nameMangling(memorySet.getName()), extraArgs); + VectorStoreQueryResult result = this.store().query(vectorStoreQuery); + return this.convertToItems(memorySet, result.getDocuments()); + } + + private String nameMangling(String name) { + return String.join("-", this.jobId, this.key, name); + } + + private List<MemorySetItem> convertToItems(MemorySet memorySet, List<Document> documents) + throws JsonProcessingException { + List<MemorySetItem> items = new ArrayList<>(); + for (Document doc : documents) { + Map<String, Object> metadata = doc.getMetadata(); + boolean compacted = (boolean) metadata.remove("compacted"); + Object createdTime; + if (compacted) { + createdTime = + new MemorySetItem.DateTimeRange( + LocalDateTime.parse( + (String) metadata.remove("created_time_start"), formatter), + LocalDateTime.parse( + (String) metadata.remove("created_time_end"), formatter)); + } else { + createdTime = + LocalDateTime.parse((String) metadata.remove("created_time"), formatter); + } + MemorySetItem item = + new MemorySetItem( + memorySet.getName(), + doc.getId(), + memorySet.getItemType() == String.class + ? 
doc.getContent() + : mapper.readValue(doc.getContent(), memorySet.getItemType()), + compacted, + createdTime, + LocalDateTime.parse( + (String) metadata.remove("last_accessed_time"), formatter), + metadata); + items.add(item); + } + return items; + } + + private void compact(MemorySet memorySet) throws Exception { + CompactionStrategy strategy = memorySet.getStrategy(); + if (strategy.type() == CompactionStrategy.Type.SUMMARIZATION) { + summarize(this, memorySet, ctx, null); + } else { + throw new RuntimeException( + String.format("Unknown compaction strategy: %s", strategy.type())); + } + } + + private ExecutorService workerExecutor() { + // TODO: shutdown executor when close. Review Comment: Do you want to address this? ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java: ########## @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.formatter; +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.mapper; + +public class CompactionFunctions { + private static final Logger LOG = LoggerFactory.getLogger(CompactionFunctions.class); + + private static Prompt DEFAULT_ANALYSIS_PROMPT = + Prompt.fromText( + "<role>\n" + + "Context Summarize Assistant\n" + + "</role>\n" + + "\n" + + "<primary_objective>\n" + + "Your sole objective in this task is to summarize the context above.\n" + + "</primary_objective>\n" + + "\n" + + "<objective_information>\n" + + "You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you should extract important topics. Notice,\n" + + "**The topics must no more than {limit}**. 
Afterwards, you should generate summarization for each topic, and record indices of the messages the summary was derived from. " + + "**There are {count} messages totally, indexed from 0 to {end}, DO NOT omit any message, even if irrelevant**. The messages involved in each topic must not overlap, and their union must equal the entire set of messages.\n" + + "</objective_information>\n" + + "\n" + + "<output_example>\n" + + "You must always respond with valid json format in this format:\n" + + "{\"topic1\": {\"summarization\": \"User ask what is 1 * 2, and the result is 3.\", \"messages\": [0,1,2,3]},\n" + + " ...\n" + + " \"topic4\": {\"summarization\": \"User ask what's the weather tomorrow, llm use the search_weather, and the answer is snow.\", \"messages\": [9,10,11,12]}\n" + + "}\n" + + "</output_example>"); + + /** + * Generate summarization of the items in the memory set. + * + * <p>This method will add the summarization to memory set, and delete original items involved + * in summarization. + * + * @param ltm The long term memory the memory set belongs to. + * @param memorySet The memory set to be summarized. + * @param ctx The runner context used to retrieve needed resources. + * @param ids The ids of items to be summarized. If not provided, all items will be involved in + * summarization. Optional. 
+ */ + @SuppressWarnings("unchecked") + public static void summarize( + BaseLongTermMemory ltm, + MemorySet memorySet, + RunnerContext ctx, + @Nullable List<String> ids) + throws Exception { + SummarizationStrategy strategy = (SummarizationStrategy) memorySet.getStrategy(); + + List<MemorySetItem> items = ltm.get(memorySet, ids); + ChatMessage response = generateSummarization(items, memorySet.getItemType(), strategy, ctx); + + LOG.debug("Items to be summarized: {}\n, Summarization: {}", items, response.getContent()); + + Map<String, Map<String, Object>> topics = + mapper.readValue(response.getContent(), Map.class); + + for (Map<String, Object> topic : topics.values()) { + String summarization = (String) topic.get("summarization"); + List<Integer> indices = (List<Integer>) topic.get("messages"); + + if (strategy.getLimit() == 1) { + indices = IntStream.range(0, items.size()).boxed().collect(Collectors.toList()); + } + + Object item; + if (memorySet.getItemType() == ChatMessage.class) { + item = new ChatMessage(MessageRole.USER, summarization); + } else { + item = summarization; + } + + List<LocalDateTime> created_times = new ArrayList<>(); + List<LocalDateTime> lastAccessedTimes = new ArrayList<>(); + List<String> itemIds = new ArrayList<>(); + for (int index : indices) { + if (items.get(index).isCompacted()) { + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getStart()); + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getEnd()); + } else { + created_times.add((LocalDateTime) items.get(index).getCreatedTime()); + } + + lastAccessedTimes.add(items.get(index).getLastAccessedTime()); + + itemIds.add(items.get(index).getId()); + } + + Preconditions.checkArgument(!created_times.isEmpty()); + String start = + created_times.stream().min(LocalDateTime::compareTo).get().format(formatter); + String end = Review Comment: Similar for end and lastAccessedTime? 
########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java: ########## @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + 
+import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.formatter; +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.mapper; + +public class CompactionFunctions { + private static final Logger LOG = LoggerFactory.getLogger(CompactionFunctions.class); + + private static Prompt DEFAULT_ANALYSIS_PROMPT = + Prompt.fromText( + "<role>\n" + + "Context Summarize Assistant\n" + + "</role>\n" + + "\n" + + "<primary_objective>\n" + + "Your sole objective in this task is to summarize the context above.\n" + + "</primary_objective>\n" + + "\n" + + "<objective_information>\n" + + "You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you should extract important topics. Notice,\n" + + "**The topics must no more than {limit}**. Afterwards, you should generate summarization for each topic, and record indices of the messages the summary was derived from. " + + "**There are {count} messages totally, indexed from 0 to {end}, DO NOT omit any message, even if irrelevant**. The messages involved in each topic must not overlap, and their union must equal the entire set of messages.\n" + + "</objective_information>\n" + + "\n" + + "<output_example>\n" + + "You must always respond with valid json format in this format:\n" + + "{\"topic1\": {\"summarization\": \"User ask what is 1 * 2, and the result is 3.\", \"messages\": [0,1,2,3]},\n" + + " ...\n" + + " \"topic4\": {\"summarization\": \"User ask what's the weather tomorrow, llm use the search_weather, and the answer is snow.\", \"messages\": [9,10,11,12]}\n" + + "}\n" + + "</output_example>"); + + /** + * Generate summarization of the items in the memory set. + * + * <p>This method will add the summarization to memory set, and delete original items involved + * in summarization. + * + * @param ltm The long term memory the memory set belongs to. + * @param memorySet The memory set to be summarized. 
+ * @param ctx The runner context used to retrieve needed resources. + * @param ids The ids of items to be summarized. If not provided, all items will be involved in + * summarization. Optional. + */ + @SuppressWarnings("unchecked") + public static void summarize( + BaseLongTermMemory ltm, + MemorySet memorySet, + RunnerContext ctx, + @Nullable List<String> ids) + throws Exception { + SummarizationStrategy strategy = (SummarizationStrategy) memorySet.getStrategy(); + + List<MemorySetItem> items = ltm.get(memorySet, ids); + ChatMessage response = generateSummarization(items, memorySet.getItemType(), strategy, ctx); + + LOG.debug("Items to be summarized: {}\n, Summarization: {}", items, response.getContent()); + + Map<String, Map<String, Object>> topics = + mapper.readValue(response.getContent(), Map.class); + + for (Map<String, Object> topic : topics.values()) { + String summarization = (String) topic.get("summarization"); + List<Integer> indices = (List<Integer>) topic.get("messages"); + + if (strategy.getLimit() == 1) { + indices = IntStream.range(0, items.size()).boxed().collect(Collectors.toList()); + } + + Object item; + if (memorySet.getItemType() == ChatMessage.class) { + item = new ChatMessage(MessageRole.USER, summarization); + } else { + item = summarization; + } + + List<LocalDateTime> created_times = new ArrayList<>(); + List<LocalDateTime> lastAccessedTimes = new ArrayList<>(); + List<String> itemIds = new ArrayList<>(); + for (int index : indices) { + if (items.get(index).isCompacted()) { + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getStart()); + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getEnd()); + } else { + created_times.add((LocalDateTime) items.get(index).getCreatedTime()); + } + + lastAccessedTimes.add(items.get(index).getLastAccessedTime()); + + itemIds.add(items.get(index).getId()); + } + + Preconditions.checkArgument(!created_times.isEmpty()); + 
String start = + created_times.stream().min(LocalDateTime::compareTo).get().format(formatter); + String end = + created_times.stream().max(LocalDateTime::compareTo).get().format(formatter); + + Preconditions.checkArgument(!lastAccessedTimes.isEmpty()); + String lastAccessedTime = + lastAccessedTimes.stream() + .max(LocalDateTime::compareTo) + .get() + .format(formatter); + + ltm.delete(memorySet, itemIds); + + ltm.add( + memorySet, + Collections.singletonList(item), + null, + Collections.singletonList( + Map.of( + "compacted", + true, + "created_time_start", + start, + "created_time_end", + end, + "last_accessed_time", + lastAccessedTime))); + System.out.print("summarize"); Review Comment: Should we remove this? ########## e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryTest.java: ########## @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.configuration.AgentConfigOptions; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup; +import org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.connector.file.src.reader.TextLineInputFormat; +import org.apache.flink.core.fs.Path; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Tag; +import 
org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Test for {@link VectorStoreLongTermMemory} + * + * <p>We use {@link ElasticsearchVectorStore} as the backend of Long-Term Memory, so need setup + * Elasticsearch server to run this test. Look <a + * href="https://www.elastic.co/docs/deploy-manage/deploy/self-managed/install-elasticsearch-docker-basic">Start + * a single-node cluster in Docker</a> for details. + * + * <p>For {@link ElasticsearchVectorStore} doesn't support security check yet, when start the + * container, should add "-e xpack.security.enabled=false" option. + */ +@Disabled("Should setup Elasticsearch server.") +public class VectorStoreLongTermMemoryTest { + private static final String NAME = "chat-history"; + private final VectorStoreLongTermMemory ltm; + private MemorySet memorySet; + private List<ChatMessage> messages; + + public static Resource getResource(String name, ResourceType type) { + if (type == ResourceType.CHAT_MODEL_CONNECTION) { + return new OllamaChatModelConnection( + ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName()) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.CHAT_MODEL) { + return new OllamaChatModelSetup( + ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "ollama-connection") + .addInitialArgument("model", "qwen3:8b") + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.EMBEDDING_MODEL_CONNECTION) { + return new OllamaEmbeddingModelConnection( + ResourceDescriptor.Builder.newBuilder( + 
OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 120) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.EMBEDDING_MODEL) { + return new OllamaEmbeddingModelSetup( + ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "embed-connection") + .addInitialArgument("model", "nomic-embed-text") + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else { + return new ElasticsearchVectorStore( + ResourceDescriptor.Builder.newBuilder(ElasticsearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embed-setup") + .addInitialArgument("host", "localhost:9200") + .addInitialArgument("dims", 768) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } + } + + public VectorStoreLongTermMemoryTest() throws Exception { + RunnerContext ctx = Mockito.mock(RunnerContext.class); + + AgentConfiguration config = new AgentConfiguration(); + config.set(LongTermMemoryOptions.ASYNC_COMPACTION, false); + Mockito.when(ctx.getConfig()).thenReturn(config); + + Mockito.when(ctx.getResource("ollama-connection", ResourceType.CHAT_MODEL_CONNECTION)) + .thenReturn(getResource("ollama-connection", ResourceType.CHAT_MODEL_CONNECTION)); + + Mockito.when(ctx.getResource("ollama-setup", ResourceType.CHAT_MODEL)) + .thenReturn(getResource("ollama-setup", ResourceType.CHAT_MODEL)); + + Mockito.when(ctx.getResource("embed-connection", ResourceType.EMBEDDING_MODEL_CONNECTION)) + .thenReturn( + getResource("embed-connection", ResourceType.EMBEDDING_MODEL_CONNECTION)); + + Mockito.when(ctx.getResource("embed-setup", ResourceType.EMBEDDING_MODEL)) + .thenReturn(getResource("embed-setup", ResourceType.EMBEDDING_MODEL)); + + Mockito.when(ctx.getResource("vector-store", ResourceType.VECTOR_STORE)) + .thenReturn(getResource("vector-store", ResourceType.VECTOR_STORE)); + + ltm = new 
VectorStoreLongTermMemory(ctx, "vector-store", "job-0001", "0001"); + } + + @BeforeEach + public void prepare(TestInfo info) throws Exception { + messages = new ArrayList<>(); + if (info.getTags().contains("skipBeforeEach")) { + return; + } + memorySet = + ltm.getOrCreateMemorySet( + NAME, + ChatMessage.class, + 100, + new SummarizationStrategy("ollama-setup", null, 1)); + for (int i = 0; i < 10; i++) { + messages.add( + new ChatMessage( + MessageRole.USER, String.format("This is the no.%s message", i))); + } + memorySet.add(messages, null, null); + } + + @AfterEach + public void cleanUp(TestInfo info) throws Exception { + if (info.getTags().contains("skipAfterEach")) { + return; + } + ltm.deleteMemorySet(NAME); + } + + @Test + public void testGetMemorySet() throws Exception { + MemorySet retrieved = ltm.getMemorySet(memorySet.getName()); + + Assertions.assertEquals(memorySet, retrieved); + } + + @Test + public void testAddAndGet() throws Exception { + List<MemorySetItem> items = memorySet.get(null); + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + Assertions.assertEquals(messages, retrieved); + } + + @Test + public void testSearch() throws Exception { + List<MemorySetItem> items = memorySet.search("The no.5 message", 1, Collections.emptyMap()); + + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + + Assertions.assertEquals(1, retrieved.size()); + Assertions.assertEquals(messages.get(5), retrieved.get(0)); + } + + @Test + @Tag("skipBeforeEach") + public void testCompact() throws Exception { + memorySet = + ltm.getOrCreateMemorySet( + NAME, + ChatMessage.class, + 8, + new SummarizationStrategy("ollama-setup", null, 2)); + messages.add(ChatMessage.user("What is flink?")); + messages.add( + ChatMessage.assistant( + "Apache Flink is a framework and distributed processing engine for stateful computations over unbounded and bounded data 
streams. Flink has been designed to run in all common cluster environments, perform computations at in-memory speed and at any scale.")); + messages.add(ChatMessage.user("What is flink agents?")); + messages.add( + ChatMessage.assistant( + "Apache Flink Agents is a brand-new sub-project from the Apache Flink community, providing an open-source framework for building event-driven streaming agents.")); + messages.add(ChatMessage.user("What's the whether tomorrow in london?")); + messages.add( + ChatMessage.assistant( + "", + Collections.singletonList( + Map.of( + "id", + "186780f8-c79d-4159-83e3-f65859835b14", + "type", + "function", + "function", + Map.of( + "name", + "get_weather", + "arguments", + Map.of( + "position", + "london", + "time", + "tomorrow")))))); + messages.add(ChatMessage.tool("snow")); + messages.add(ChatMessage.assistant("Tomorrow weather for london is snow.")); + memorySet.add(messages, null, null); + + List<MemorySetItem> items = memorySet.get(null); + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + + Assertions.assertEquals(2, items.size()); + Assertions.assertTrue(items.get(0).isCompacted()); + Assertions.assertTrue(items.get(1).isCompacted()); + Assertions.assertInstanceOf( + MemorySetItem.DateTimeRange.class, items.get(0).getCreatedTime()); + Assertions.assertInstanceOf( + MemorySetItem.DateTimeRange.class, items.get(1).getCreatedTime()); + Assertions.assertEquals(2, memorySet.size()); + } + + @Test + @Tag("skipBeforeEach") + @Tag("skipAfterEach") + public void testUsingLtmInAction() throws Exception { + ElasticsearchVectorStore es = + new ElasticsearchVectorStore( + ResourceDescriptor.Builder.newBuilder( + ElasticsearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embed-setup") + .addInitialArgument("host", "localhost:9200") + .addInitialArgument("dims", 768) + .build(), + VectorStoreLongTermMemoryTest::getResource); + try { + // Set up the Flink 
streaming environment and the Agents execution environment. + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + agentsEnv.getConfig().set(AgentConfigOptions.JOB_IDENTIFIER, "ltm_test_job"); + agentsEnv + .getConfig() + .set( + LongTermMemoryOptions.BACKEND, + LongTermMemoryOptions.LongTermMemoryBackend.EXTERNAL_VECTOR_STORE); + agentsEnv + .getConfig() + .set(LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME, "vectorStore"); + agentsEnv.getConfig().set(LongTermMemoryOptions.ASYNC_COMPACTION, true); + + DataStream<String> inputStream = + env.fromSource( + FileSource.forRecordStreamFormat( + new TextLineInputFormat(), + new Path( + Objects.requireNonNull( + this.getClass() + .getClassLoader() + .getResource( + "input_data.txt")) + .getPath())) + .build(), + WatermarkStrategy.noWatermarks(), + "ltm-test-agent"); + DataStream<VectorStoreLongTermMemoryAgent.ProductReview> reviewDataStream = + inputStream.map( + x -> + VectorStoreLongTermMemoryAgent.mapper.readValue( + x, VectorStoreLongTermMemoryAgent.ProductReview.class)); + + // Use the ReviewAnalysisAgent to analyze each product review. + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + reviewDataStream, + VectorStoreLongTermMemoryAgent.ProductReview::getId) + .apply(new VectorStoreLongTermMemoryAgent()) + .toDataStream(); + + // Print the analysis results to stdout. + outputStream.print(); + + // Execute the Flink pipeline. + agentsEnv.execute(); + + // check async compaction + System.out.print(es.get(null, "ltm_test_job-2-test-ltm", Collections.emptyMap())); Review Comment: Was this for debugging? Remove this? 
########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import 
org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return 
(BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for 
(int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { + CompletableFuture.runAsync( + () -> { + try { + compact(memorySet); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + this.workerExecutor()) + .exceptionally( + e -> { + throw new RuntimeException( + String.format( + "Compaction for %s failed", + this.nameMangling(memorySet.getName())), + e); + }); + } else { + this.compact(memorySet); + } + } + + return itemIds; + } + + @Override + public List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) Review Comment: Add pagination parameters or a separate `getAll()` method with warnings since for a large memory set this could cause OOM? ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java: ########## @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.formatter; +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.mapper; + +public class CompactionFunctions { + private static final Logger LOG = LoggerFactory.getLogger(CompactionFunctions.class); + + private static Prompt DEFAULT_ANALYSIS_PROMPT = + Prompt.fromText( + "<role>\n" + + "Context Summarize Assistant\n" + + "</role>\n" + + "\n" + + "<primary_objective>\n" + + "Your sole objective in this task is to 
summarize the context above.\n" + + "</primary_objective>\n" + + "\n" + + "<objective_information>\n" + + "You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you should extract important topics. Notice,\n" + + "**The topics must no more than {limit}**. Afterwards, you should generate summarization for each topic, and record indices of the messages the summary was derived from. " + + "**There are {count} messages totally, indexed from 0 to {end}, DO NOT omit any message, even if irrelevant**. The messages involved in each topic must not overlap, and their union must equal the entire set of messages.\n" + + "</objective_information>\n" + + "\n" + + "<output_example>\n" + + "You must always respond with valid json format in this format:\n" + + "{\"topic1\": {\"summarization\": \"User ask what is 1 * 2, and the result is 3.\", \"messages\": [0,1,2,3]},\n" + + " ...\n" + + " \"topic4\": {\"summarization\": \"User ask what's the weather tomorrow, llm use the search_weather, and the answer is snow.\", \"messages\": [9,10,11,12]}\n" + + "}\n" + + "</output_example>"); + + /** + * Generate summarization of the items in the memory set. + * + * <p>This method will add the summarization to memory set, and delete original items involved + * in summarization. + * + * @param ltm The long term memory the memory set belongs to. + * @param memorySet The memory set to be summarized. + * @param ctx The runner context used to retrieve needed resources. + * @param ids The ids of items to be summarized. If not provided, all items will be involved in + * summarization. Optional. 
+ */ + @SuppressWarnings("unchecked") + public static void summarize( + BaseLongTermMemory ltm, + MemorySet memorySet, + RunnerContext ctx, + @Nullable List<String> ids) + throws Exception { + SummarizationStrategy strategy = (SummarizationStrategy) memorySet.getStrategy(); + + List<MemorySetItem> items = ltm.get(memorySet, ids); + ChatMessage response = generateSummarization(items, memorySet.getItemType(), strategy, ctx); + + LOG.debug("Items to be summarized: {}\n, Summarization: {}", items, response.getContent()); + + Map<String, Map<String, Object>> topics = + mapper.readValue(response.getContent(), Map.class); + + for (Map<String, Object> topic : topics.values()) { + String summarization = (String) topic.get("summarization"); + List<Integer> indices = (List<Integer>) topic.get("messages"); + + if (strategy.getLimit() == 1) { + indices = IntStream.range(0, items.size()).boxed().collect(Collectors.toList()); + } + + Object item; + if (memorySet.getItemType() == ChatMessage.class) { + item = new ChatMessage(MessageRole.USER, summarization); + } else { + item = summarization; + } + + List<LocalDateTime> created_times = new ArrayList<>(); + List<LocalDateTime> lastAccessedTimes = new ArrayList<>(); + List<String> itemIds = new ArrayList<>(); + for (int index : indices) { + if (items.get(index).isCompacted()) { + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getStart()); + created_times.add( + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()) + .getEnd()); + } else { + created_times.add((LocalDateTime) items.get(index).getCreatedTime()); + } + + lastAccessedTimes.add(items.get(index).getLastAccessedTime()); + + itemIds.add(items.get(index).getId()); + } + + Preconditions.checkArgument(!created_times.isEmpty()); + String start = Review Comment: Can we calculate global min at line #124 instead of adding to list and iterating again? -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected]
