Skip to content

Commit

Permalink
Added RecursiveCharacterSplitter (#14)
Browse files Browse the repository at this point in the history
* Added RecursiveCharacterSplitter

Found the description on LangChain website and created my implementation, seems to be working. It will split documents recursively by different characters - starting with "\n\n", then "\n", then " ". This is nice because it will try to keep all the semantically relevant content in the same place for as long as possible.

Orig from langchain: https://js.langchain.com/docs/modules/indexes/text_splitters/examples/recursive_character

* Reduce search performance test for CI

---------

Co-authored-by: ZachNagengast <[email protected]>
  • Loading branch information
LexiestLeszek and ZachNagengast authored Jul 5, 2023
1 parent 789c72d commit 5f5b00a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//
// RecursiveCharacterSplitter.swift
//
// Created by Leszek Mielnikow on 03/07/2023.
//

import Foundation
import SimilaritySearchKit

public class RecursiveCharacterSplitter: TextSplitterProtocol {
let characterSplitter: CharacterSplitter

public init() {
characterSplitter = CharacterSplitter()
}

public func split(text: String, chunkSize: Int = 100, overlapSize: Int = 0) -> ([String], [[String]]?) {
let separators = ["\n\n", "\n", ".", " "]

for separator in separators {
let splits = text.components(separatedBy: separator)
let (isValid, splitTokens) = isSplitValid(chunks: splits, maxChunkSize: chunkSize)

if isValid {
var chunks: [String] = []
var chunkTokens: [[String]] = []

var currentChunkTokens: [String] = []
var currentChunkSize: Int = 0
var currentChunkSplit: String = ""

for (idx, tokens) in splitTokens.enumerated() {
let tokensSize = tokens.count

if currentChunkSize + tokensSize < chunkSize {
currentChunkTokens.append(contentsOf: tokens)
currentChunkSize += tokensSize
currentChunkSplit += splits[idx] + separator

} else {
chunks.append(currentChunkSplit.trimmingCharacters(in: .whitespaces))
chunkTokens.append(characterSplitter.split(text: currentChunkSplit, chunkSize: chunkSize).0)

// reset current
currentChunkTokens = tokens
currentChunkSize = tokensSize
currentChunkSplit = splits[idx] + separator
}
}

// Add the last chunk if it's not empty
if !currentChunkSplit.isEmpty {
chunks.append(currentChunkSplit.trimmingCharacters(in: .whitespaces))
chunkTokens.append(characterSplitter.split(text: currentChunkSplit, chunkSize: chunkSize).0)
}

return (chunks, chunkTokens)
}
}

return ([], [])
}

// MARK: - Helpers

private func isSplitValid(chunks: [String], maxChunkSize: Int) -> (Bool, [[String]]) {
var splitTokens: [[String]] = []

for chunk in chunks {
let tokens = characterSplitter.split(text: chunk, chunkSize: maxChunkSize).0
if chunk.count > maxChunkSize {
return (false, [])
}
splitTokens.append(tokens)
}

return (true, splitTokens)
}
}
2 changes: 1 addition & 1 deletion Tests/SimilaritySearchKitTests/BenchmarkTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class BenchmarkTests: XCTestCase {
}

func testDistilbertPerformanceSearch() {
let testAmount = 10
let testAmount = 2
let passageIds = Array(0..<testAmount).map { _ in UUID().uuidString }
let passageTexts = Array(MSMarco.passageTexts[0..<testAmount])
let passageUrls = MSMarco.passageUrls[0..<testAmount].map { url in ["source": url] }
Expand Down

0 comments on commit 5f5b00a

Please sign in to comment.