Skip to content

Commit

Permalink
Inmemory file backed (#29)
Browse files Browse the repository at this point in the history
* 🚧 PersistInMemoryEmbeddingStore

* PersistInMemoryEmbeddingStore
  • Loading branch information
isontheline authored Jun 3, 2024
1 parent d4bbd69 commit 722fee2
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package ai.dragon.enumeration;

public enum VectorStoreType {
InMemoryEmbeddingStore("InMemoryEmbeddingStore");
InMemoryEmbeddingStore("InMemoryEmbeddingStore"),
PersistInMemoryEmbeddingStore("PersistInMemoryEmbeddingStore");

private String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public class InMemoryEmbeddingStoreSettings {
private String persistance;
public class PersistInMemoryEmbeddingStoreSettings {
private Integer flushSecs;

public InMemoryEmbeddingStoreSettings() {
persistance = ":memory:";
public PersistInMemoryEmbeddingStoreSettings() {
flushSecs = 60;
}
}
42 changes: 9 additions & 33 deletions backend/src/main/java/ai/dragon/service/EmbeddingStoreService.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ai.dragon.service;

import java.io.File;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
Expand All @@ -13,9 +12,10 @@
import ai.dragon.component.DirectoryStructureComponent;
import ai.dragon.entity.SiloEntity;
import ai.dragon.listener.EntityChangeListener;
import ai.dragon.properties.store.InMemoryEmbeddingStoreSettings;
import ai.dragon.properties.store.PersistInMemoryEmbeddingStoreSettings;
import ai.dragon.repository.SiloRepository;
import ai.dragon.util.IniSettingUtil;
import ai.dragon.util.embedding.store.inmemory.persist.PersistInMemoryEmbeddingStore;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
Expand All @@ -41,7 +41,6 @@ public class EmbeddingStoreService {
private DirectoryStructureComponent directoryStructureComponent;

private EntityChangeListener<SiloEntity> entityChangeListener;
private Map<UUID, Path> inMemoryEmbededdingStorePersistPaths = new HashMap<>();

@PostConstruct
private void init() {
Expand Down Expand Up @@ -79,22 +78,10 @@ public EmbeddingStore<TextSegment> retrieveEmbeddingStore(UUID siloUuid) throws
public void closeEmbeddingStore(UUID siloUuid) {
EmbeddingStore<TextSegment> embeddingStore = embeddingStores.get(siloUuid);
if (embeddingStore != null) {
persistEmbeddingStore(siloUuid);
embeddingStores.remove(siloUuid);
}
}

public void persistEmbeddingStore(UUID siloUuid) {
EmbeddingStore<TextSegment> embeddingStore = embeddingStores.get(siloUuid);
if (embeddingStore instanceof InMemoryEmbeddingStore) {
InMemoryEmbeddingStore<TextSegment> inMemoryEmbeddingStore = (InMemoryEmbeddingStore<TextSegment>) embeddingStore;
Path vectorFile = inMemoryEmbededdingStorePersistPaths.get(siloUuid);
if (vectorFile != null) {
inMemoryEmbeddingStore.serializeToFile(vectorFile);
}
}
}

public void closeAllEmbeddingStores() {
for (UUID siloUuid : embeddingStores.keySet()) {
closeEmbeddingStore(siloUuid);
Expand Down Expand Up @@ -128,29 +115,18 @@ public void query(UUID siloUuid, String query) throws Exception {
}

private EmbeddingStore<TextSegment> buildEmbeddingStore(SiloEntity siloEntity) throws Exception {
EmbeddingStore<TextSegment> embeddingStore = null;

switch (siloEntity.getVectorStoreType()) {
case InMemoryEmbeddingStore:
InMemoryEmbeddingStoreSettings storeSettings = IniSettingUtil.convertIniSettingsToObject(
siloEntity.getVectorStoreSettings(), InMemoryEmbeddingStoreSettings.class);
String persistance = storeSettings.getPersistance();
embeddingStore = new InMemoryEmbeddingStore<>();
if (!":memory:".equals(persistance)) {
File vectorFile = new File(directoryStructureComponent.directoryFor("vector"),
siloEntity.getUuid().toString() + ".json");
Path vectorPath = vectorFile.toPath();
inMemoryEmbededdingStorePersistPaths.put(siloEntity.getUuid(), vectorPath);
if (vectorFile.exists()) {
embeddingStore = InMemoryEmbeddingStore.fromFile(vectorPath);
}
}
break;
return new InMemoryEmbeddingStore<>();
case PersistInMemoryEmbeddingStore:
PersistInMemoryEmbeddingStoreSettings storeSettings = IniSettingUtil.convertIniSettingsToObject(
siloEntity.getVectorStoreSettings(), PersistInMemoryEmbeddingStoreSettings.class);
File vectorFile = new File(directoryStructureComponent.directoryFor("vector"),
siloEntity.getUuid().toString() + ".json");
return PersistInMemoryEmbeddingStore.createFromFileAndSettings(vectorFile, storeSettings);
default:
throw new UnsupportedOperationException(
String.format("VectorStoreType not supported : %s", siloEntity.getVectorStoreType()));
}

return embeddingStore;
}
}
3 changes: 0 additions & 3 deletions backend/src/main/java/ai/dragon/service/IngestorService.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ public void runSiloIngestion(SiloEntity siloEntity, Consumer<Integer> progressCa
logCallback.accept(SiloIngestLoaderLogMessage.builder()
.message(String.format("Will ingest %d documents to Silo...", documents.size())).build());
ingestDocumentsToSilo(documents, siloEntity, progressCallback, logCallback);
logCallback.accept(SiloIngestLoaderLogMessage.builder()
.message(String.format("Persisting the Embedding Store...", documents.size())).build());
embeddingStoreService.persistEmbeddingStore(siloEntity.getUuid());
}

private void ingestDocumentsToSilo(List<Document> documents, SiloEntity siloEntity,
Expand Down
37 changes: 37 additions & 0 deletions backend/src/main/java/ai/dragon/util/Debouncer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package ai.dragon.util;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

// Credits : https://stackoverflow.com/a/38296055/8102448
public class Debouncer {
private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
private final ConcurrentHashMap<Object, Future<?>> delayedMap = new ConcurrentHashMap<>();

/**
* Debounces {@code callable} by {@code delay}, i.e., schedules it to be executed after {@code delay},
* or cancels its execution if the method is called with the same key within the {@code delay} again.
*/
public void debounce(final Object key, final Runnable runnable, long delay, TimeUnit unit) {
final Future<?> prev = delayedMap.put(key, scheduler.schedule(new Runnable() {
@Override
public void run() {
try {
runnable.run();
} finally {
delayedMap.remove(key);
}
}
}, delay, unit));
if (prev != null) {
prev.cancel(true);
}
}

public void shutdown() {
scheduler.shutdownNow();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ai.dragon.util.embedding.store.inmemory.persist;

import java.lang.reflect.Type;
import java.util.concurrent.CopyOnWriteArrayList;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.ToNumberPolicy;
import com.google.gson.reflect.TypeToken;

import dev.langchain4j.data.segment.TextSegment;

public class GsonPersistInMemoryEmbeddingStoreJsonCodec {
private static final Gson GSON = new GsonBuilder()
.setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE)
.create();

private static final Type TYPE = new TypeToken<CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>>>() {
}.getType();

public CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>> fromJson(String json) {
return GSON.fromJson(json, TYPE);
}

public String toJson(CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>> entries) {
return GSON.toJson(entries);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package ai.dragon.util.embedding.store.inmemory.persist;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.dragon.properties.store.PersistInMemoryEmbeddingStoreSettings;
import ai.dragon.util.Debouncer;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.filter.Filter;

// Based on :
// https://raw.githubusercontent.com/langchain4j/langchain4j/main/langchain4j/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java
public class PersistInMemoryEmbeddingStore implements EmbeddingStore<TextSegment> {
private final Logger logger = LoggerFactory.getLogger(this.getClass());
private final CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>> entries = new CopyOnWriteArrayList<>();
private final Debouncer debouncer = new Debouncer();

private File persistFile;
private PersistInMemoryEmbeddingStoreSettings settings;

private PersistInMemoryEmbeddingStore() {
}

public static PersistInMemoryEmbeddingStore createFromFileAndSettings(File persistFile,
PersistInMemoryEmbeddingStoreSettings settings) {
PersistInMemoryEmbeddingStore store = new PersistInMemoryEmbeddingStore();
store.persistFile = persistFile;
store.restoreFromFileNow();
return store;
}

// TODO Remove funcs...

public void flushToDisk() {
debouncer.debounce(Void.class, new Runnable() {
@Override
public void run() {
flushToDiskNow();
}
}, settings.getFlushSecs(), TimeUnit.SECONDS);
}

@Override
public String add(Embedding embedding) {
String id = Utils.randomUUID();
add(id, embedding);
return id;
}

@Override
public void add(String id, Embedding embedding) {
add(id, embedding, null);
}

@Override
public String add(Embedding embedding, TextSegment embedded) {
String id = Utils.randomUUID();
add(id, embedding, embedded);
return id;
}

public void add(String id, Embedding embedding, TextSegment embedded) {
entries.add(new PersistInMemoryEntry<>(id, embedding, embedded));
flushToDisk();
}

@Override
public List<String> addAll(List<Embedding> embeddings) {
List<PersistInMemoryEntry<TextSegment>> newEntries = embeddings.stream()
.map(embedding -> new PersistInMemoryEntry<TextSegment>(Utils.randomUUID(), embedding))
.collect(Collectors.toList());
return add(newEntries);
}

@Override
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
if (embeddings.size() != embedded.size()) {
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
}
List<PersistInMemoryEntry<TextSegment>> newEntries = IntStream.range(0, embeddings.size())
.mapToObj(i -> new PersistInMemoryEntry<>(Utils.randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(Collectors.toList());
return add(newEntries);
}

private List<String> add(List<PersistInMemoryEntry<TextSegment>> newEntries) {
entries.addAll(newEntries);
flushToDisk();
return newEntries.stream()
.map(entry -> entry.getId())
.collect(Collectors.toList());
}

private void restoreFromFileNow() {
try {
logger.debug(String.format("Restoring embeddings from file : %s", persistFile));
if (persistFile.exists()) {
String json = new String(Files.readAllBytes(persistFile.toPath()));
CopyOnWriteArrayList<PersistInMemoryEntry<TextSegment>> restoredEntries = codec().fromJson(json);
entries.addAll(restoredEntries);
logger.info(String.format("Restored %d embeddings from file : %s", entries.size(), persistFile));
} else {
logger.info(String.format("No embeddings found in file : %s", persistFile));
}
} catch (Exception ex) {
logger.error(String.format("Failed to restore from file : %s", persistFile), ex);
throw new RuntimeException(ex);
}
}

private void flushToDiskNow() {
try {
logger.debug(String.format("Flushing %d embeddings to file : %s", entries.size(), persistFile));
String json = codec().toJson(this.entries);
Files.write(persistFile.toPath(), json.getBytes(), StandardOpenOption.CREATE,
StandardOpenOption.TRUNCATE_EXISTING);
} catch (Exception ex) {
logger.error(String.format("Failed to flush to file : %s", persistFile), ex);
throw new RuntimeException(ex);
}
}

private static GsonPersistInMemoryEmbeddingStoreJsonCodec codec() {
return new GsonPersistInMemoryEmbeddingStoreJsonCodec();
}

@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
Comparator<EmbeddingMatch<TextSegment>> comparator = Comparator.comparingDouble(EmbeddingMatch::score);
PriorityQueue<EmbeddingMatch<TextSegment>> matches = new PriorityQueue<>(comparator);
Filter filter = embeddingSearchRequest.filter();
for (PersistInMemoryEntry<TextSegment> entry : entries) {
if (filter != null && entry.getEmbedded() instanceof TextSegment) {
Metadata metadata = ((TextSegment) entry.getEmbedded()).metadata();
if (!filter.test(metadata)) {
continue;
}
}
double cosineSimilarity = CosineSimilarity.between(entry.getEmbedding(),
embeddingSearchRequest.queryEmbedding());
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= embeddingSearchRequest.minScore()) {
matches.add(new EmbeddingMatch<>(score, entry.getId(), entry.getEmbedding(), entry.getEmbedded()));
if (matches.size() > embeddingSearchRequest.maxResults()) {
matches.poll();
}
}
}
List<EmbeddingMatch<TextSegment>> result = new ArrayList<>(matches);
result.sort(comparator);
Collections.reverse(result);
return new EmbeddingSearchResult<>(result);
}
}
Loading

0 comments on commit 722fee2

Please sign in to comment.