Skip to content

Commit

Permalink
Persist inmemory remove func (#30)
Browse files Browse the repository at this point in the history
* Refactoring PersistInMemoryEmbeddingStore with a builder

* Adding missing "remove" functions to PersistInMemoryEmbeddingStore

* feat: Add EmbeddingMatchResponse class for embedding search results
  • Loading branch information
isontheline authored Jun 4, 2024
1 parent 722fee2 commit 07484e1
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.UUID;

import org.dizitart.no2.exceptions.UniqueConstraintException;
import org.springframework.http.HttpStatus;
Expand Down Expand Up @@ -57,11 +58,19 @@ public T create(AbstractRepository<T> repository, Function<T, T> beforeSaveCallb
return entity;
}

public T get(UUID uuid, AbstractRepository<T> repository) {
return get(uuid.toString(), repository);
}

public T get(String uuid, AbstractRepository<T> repository) {
return repository.getByUuid(uuid)
.orElseThrow(() -> new ResponseStatusException(HttpStatus.NOT_FOUND, "Entity not found."));
}

public void delete(UUID uuid, AbstractRepository<T> repository) {
delete(uuid.toString(), repository);
}

public void delete(String uuid, AbstractRepository<T> repository) {
if (!repository.exists(uuid)) {
throw new ResponseStatusException(HttpStatus.NOT_FOUND, "Entity not found.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Map;
import java.util.UUID;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
Expand All @@ -17,6 +18,7 @@

import ai.dragon.entity.SiloEntity;
import ai.dragon.repository.SiloRepository;
import ai.dragon.service.SiloService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Content;
Expand All @@ -30,6 +32,9 @@ public class SiloBackendApiController extends AbstractCrudBackendApiController<S
@Autowired
private SiloRepository siloRepository;

@Autowired
private SiloService siloService;

@GetMapping("/")
@ApiResponse(responseCode = "200", description = "List has been successfully retrieved.")
@Operation(summary = "List all Silos", description = "Returns all Silo entities stored in the database.")
Expand Down Expand Up @@ -68,7 +73,9 @@ public SiloEntity update(
@ApiResponse(responseCode = "200", description = "Silo has been successfully deleted.")
@ApiResponse(responseCode = "404", description = "Silo not found.", content = @Content)
@Operation(summary = "Delete a Silo", description = "Deletes one Silo entity from its UUID stored in the database.")
public void delete(@PathVariable("uuid") @Parameter(description = "Identifier of the Silo") String uuid) {
public void delete(@PathVariable("uuid") @Parameter(description = "Identifier of the Silo") UUID uuid)
throws Exception {
siloService.removeEmbeddings(uuid);
super.delete(uuid, siloRepository);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ai.dragon.controller.api.ragapi;

import java.util.List;
import java.util.UUID;
import java.util.ArrayList;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PathVariable;
Expand All @@ -10,6 +12,10 @@
import org.springframework.web.bind.annotation.RestController;

import ai.dragon.service.EmbeddingStoreService;
import ai.dragon.util.embedding.search.EmbeddingMatchResponse;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
Expand All @@ -26,10 +32,19 @@ public class SearchRagApiController {
@PostMapping("/documents/silo/{uuid:[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}}")
@ApiResponse(responseCode = "200", description = "Documents have been successfully retrieved.")
@Operation(summary = "Search documents inside a Silo", description = "Search documents from the Silo.")
public void searchDocumentsInSilo(
public List<EmbeddingMatchResponse> searchDocumentsInSilo(
@PathVariable("uuid") @Parameter(description = "Identifier of the Silo") UUID uuid,
@RequestBody String query)
throws Exception {
embeddingStoreService.query(uuid, query);
List<EmbeddingMatchResponse> searchResults = new ArrayList<>();
EmbeddingSearchResult<TextSegment> embeddingSearchResult = embeddingStoreService.query(uuid, query);
for (EmbeddingMatch<TextSegment> embeddingMatch : embeddingSearchResult.matches()) {
searchResults.add(EmbeddingMatchResponse.builder()
.score(embeddingMatch.score())
//.metadata(embeddingMatch.embedded().metadata())
.text(embeddingMatch.embedded().text())
.build());
}
return searchResults;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

import ai.dragon.util.embedding.store.inmemory.persist.PersistInMemoryEmbeddingStore;
import lombok.Data;

@Data
Expand All @@ -10,6 +11,6 @@ public class PersistInMemoryEmbeddingStoreSettings {
private Integer flushSecs;

public PersistInMemoryEmbeddingStoreSettings() {
flushSecs = 60;
flushSecs = PersistInMemoryEmbeddingStore.DEFAULT_FLUSH_SECS;
}
}
27 changes: 11 additions & 16 deletions backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;

Expand Down Expand Up @@ -88,13 +86,12 @@ public void closeAllEmbeddingStores() {
}
}

public void clearEmbeddingStore(UUID siloUuid) {
if (embeddingStores.containsKey(siloUuid)) {
embeddingStores.get(siloUuid).removeAll();
}
public void clearEmbeddingStore(UUID siloUuid) throws Exception {
EmbeddingStore<TextSegment> embeddingStore = retrieveEmbeddingStore(siloUuid);
embeddingStore.removeAll();
}

public void query(UUID siloUuid, String query) throws Exception {
public EmbeddingSearchResult<TextSegment> query(UUID siloUuid, String query) throws Exception {
SiloEntity siloEntity = siloRepository.getByUuid(siloUuid).orElseThrow();
EmbeddingStore<TextSegment> embeddingStore = retrieveEmbeddingStore(siloUuid);
EmbeddingModel embeddingModel = embeddingModelService.modelForEntity(siloEntity);
Expand All @@ -105,25 +102,23 @@ public void query(UUID siloUuid, String query) throws Exception {
// .filter(onlyForUser1)
.maxResults(10)
.build();
EmbeddingSearchResult<TextSegment> embeddingSearchResult1 = embeddingStore.search(embeddingSearchRequest1);
for (EmbeddingMatch<TextSegment> embeddingMatch : embeddingSearchResult1.matches()) {
System.out.println("=> " + embeddingMatch.score() + " : " +
embeddingMatch.embedded().metadata());
System.out.println(embeddingMatch.embedded().text());
System.out.println("=====");
}
return embeddingStore.search(embeddingSearchRequest1);
}

private EmbeddingStore<TextSegment> buildEmbeddingStore(SiloEntity siloEntity) throws Exception {
switch (siloEntity.getVectorStoreType()) {
case InMemoryEmbeddingStore:
return new InMemoryEmbeddingStore<>();
return PersistInMemoryEmbeddingStore.builder().build();
case PersistInMemoryEmbeddingStore:
PersistInMemoryEmbeddingStoreSettings storeSettings = IniSettingUtil.convertIniSettingsToObject(
siloEntity.getVectorStoreSettings(), PersistInMemoryEmbeddingStoreSettings.class);
File vectorFile = new File(directoryStructureComponent.directoryFor("vector"),
siloEntity.getUuid().toString() + ".json");
return PersistInMemoryEmbeddingStore.createFromFileAndSettings(vectorFile, storeSettings);
return PersistInMemoryEmbeddingStore
.builder()
.flushSecs(storeSettings.getFlushSecs())
.persistFile(vectorFile)
.build();
default:
throw new UnsupportedOperationException(
String.format("VectorStoreType not supported : %s", siloEntity.getVectorStoreType()));
Expand Down
6 changes: 6 additions & 0 deletions backend/src/main/java/ai/dragon/service/IngestorService.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ public void runSiloIngestion(SiloEntity siloEntity, Consumer<Integer> progressCa
.message(String.format("Listing documents using '%s' Ingestor Loader...", ingestorLoader.getClass()))
.build());
List<Document> documents = ingestorLoader.listDocuments();
// TODO ?
/*logCallback.accept(SiloIngestLoaderLogMessage.builder()
.message(String.format("Cleaning all current embeddings of Silo '%s'...", siloEntity.getUuid()))
.build());
embeddingStoreService.clearEmbeddingStore(siloEntity.getUuid());*/
logCallback.accept(SiloIngestLoaderLogMessage.builder()
.message(String.format("Will ingest %d documents to Silo...", documents.size())).build());
ingestDocumentsToSilo(documents, siloEntity, progressCallback, logCallback);
// TODO Need to clean embeddings unlinked to documents listing
}

private void ingestDocumentsToSilo(List<Document> documents, SiloEntity siloEntity,
Expand Down
10 changes: 5 additions & 5 deletions backend/src/main/java/ai/dragon/service/SiloService.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ public void onChangeEvent(CollectionEventInfo<?> collectionEventInfo, SiloEntity
switch (collectionEventInfo.getEventType()) {
case Remove:
removeFarmLinks(entity);
removeEmbeddings(entity);
break;
default:
break;
Expand All @@ -57,6 +56,11 @@ public void rebuildSilo(UUID uuid) {
siloJobService.startSiloIngestorJobNow(entity);
}

public void removeEmbeddings(UUID uuid) throws Exception {
SiloEntity entity = siloRepository.getByUuid(uuid).orElseThrow();
embeddingStoreService.clearEmbeddingStore(entity.getUuid());
}

private void removeFarmLinks(SiloEntity entity) {
for (FarmEntity farm : farmRepository.find()) {
if (farm.getSilos().contains(entity.getUuid())) {
Expand All @@ -65,8 +69,4 @@ private void removeFarmLinks(SiloEntity entity) {
}
}
}

private void removeEmbeddings(SiloEntity entity) {
embeddingStoreService.clearEmbeddingStore(entity.getUuid());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package ai.dragon.util.embedding.search;

import java.util.Map;

import lombok.Builder;
import lombok.Getter;

@Builder
@Getter
public class EmbeddingMatchResponse {
private Double score;
private String text;
private Map<String, String> metadata;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
Expand All @@ -16,7 +17,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.dragon.properties.store.PersistInMemoryEmbeddingStoreSettings;
import ai.dragon.util.Debouncer;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
Expand All @@ -29,37 +29,62 @@
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;
import lombok.Builder;

// Based on :
// Initialy based on:
// https://raw.githubusercontent.com/langchain4j/langchain4j/main/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java
@Builder
public class PersistInMemoryEmbeddingStore implements EmbeddingStore<TextSegment> {
public static final int DEFAULT_FLUSH_SECS = 60;

private final Logger logger = LoggerFactory.getLogger(this.getClass());
private final CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>> entries = new CopyOnWriteArrayList<>();
private final Debouncer debouncer = new Debouncer();

private File persistFile;
private PersistInMemoryEmbeddingStoreSettings settings;
private Integer flushSecs = DEFAULT_FLUSH_SECS;

private PersistInMemoryEmbeddingStore() {
public static Builder builder() {
return new Builder();
}

public static PersistInMemoryEmbeddingStore createFromFileAndSettings(File persistFile,
PersistInMemoryEmbeddingStoreSettings settings) {
PersistInMemoryEmbeddingStore store = new PersistInMemoryEmbeddingStore();
store.persistFile = persistFile;
store.restoreFromFileNow();
return store;
public static class Builder extends PersistInMemoryEmbeddingStoreBuilder {
Builder() {
super();
}

@Override
public PersistInMemoryEmbeddingStore build() {
PersistInMemoryEmbeddingStore embeddingStore = super.build();
if (embeddingStore.persistFile != null) {
embeddingStore.restoreFromFile();
}
return embeddingStore;
}
}

// TODO Remove funcs...
@Override
public void remove(String id) {
entries.removeIf(entry -> entry.getId().equals(id));
flushToDisk();
}

public void flushToDisk() {
debouncer.debounce(Void.class, new Runnable() {
@Override
public void run() {
flushToDiskNow();
}
}, settings.getFlushSecs(), TimeUnit.SECONDS);
@Override
public void removeAll(Collection<String> ids) {
entries.removeIf(entry -> ids.contains(entry.getId()));
flushToDisk();
}

@Override
public void removeAll(Filter filter) {
entries.removeIf(entry -> filter.test(entry.getEmbedded().metadata()));
flushToDisk();
}

@Override
public void removeAll() {
entries.clear();
flushToDisk();
}

@Override
Expand Down Expand Up @@ -113,7 +138,26 @@ private List<String> add(List<PersistInMemoryEntry<TextSegment>> newEntries) {
.collect(Collectors.toList());
}

private void restoreFromFileNow() {
private void flushToDisk() {
if (persistFile == null) {
return;
}
int flushSecs = this.flushSecs == null ? DEFAULT_FLUSH_SECS : this.flushSecs;
debouncer.debounce(Void.class, new Runnable() {
@Override
public void run() {
if (persistFile == null) {
return;
}
flushToDiskNow();
}
}, flushSecs, TimeUnit.SECONDS);
}

private void restoreFromFile() {
if (persistFile == null) {
return;
}
try {
logger.debug(String.format("Restoring embeddings from file : %s", persistFile));
if (persistFile.exists()) {
Expand All @@ -131,8 +175,12 @@ private void restoreFromFileNow() {
}

private void flushToDiskNow() {
if (persistFile == null) {
logger.warn("Won't flush to disk because persistFile is null.");
return;
}
logger.debug(String.format("Flushing %d embeddings to file : %s", entries.size(), persistFile));
try {
logger.debug(String.format("Flushing %d embeddings to file : %s", entries.size(), persistFile));
String json = codec().toJson(this.entries);
Files.write(persistFile.toPath(), json.getBytes(), StandardOpenOption.CREATE,
StandardOpenOption.TRUNCATE_EXISTING);
Expand Down
Binary file not shown.

0 comments on commit 07484e1

Please sign in to comment.