Skip to content

Commit

Permalink
Fix break segments and overlap (#58)
Browse files Browse the repository at this point in the history
* Fix the issue of dropping cuts

* fix overlap, still has some problems

* refactor is_overlap

* Fix overlap

* release v0.10
pkufool authored Jan 10, 2024
1 parent 98aad14 commit 6847d91
Showing 7 changed files with 223 additions and 95 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
project(textsearch)

set(TS_VERSION "0.9")
set(TS_VERSION "0.10")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
10 changes: 6 additions & 4 deletions examples/libriheavy/matching.py
Original file line number Diff line number Diff line change
@@ -88,13 +88,14 @@ def get_params() -> AttributeDict:
# you can find the docs in textsearch/match.py#align_queries
"num_close_matches": 2,
"segment_length": 5000,
"reference_length_difference": 0.1,
"reference_length_difference": 0.4,
"min_matched_query_ratio": 0.33,
# parameters for splitting aligned queries
# you can find the docs in textsearch/match.py#split_aligned_queries
"preceding_context_length": 1000,
"timestamp_position": "current",
"silence_length_to_break": 0.45,
"overlap_ratio": 0.4,
"min_duration": 2,
"max_duration": 30,
"expected_duration": (5, 20),
@@ -188,6 +189,7 @@ def load_data(
books.append(book)

if not transcripts:
logging.warning(f"No transcripts found.")
return {}

logging.debug(f"Worker[{worker_index}] loading cuts and books done.")
@@ -321,6 +323,7 @@ def split(
preceding_context_length=params.preceding_context_length,
timestamp_position=params.timestamp_position,
silence_length_to_break=params.silence_length_to_break,
overlap_ratio=params.overlap_ratio,
min_duration=params.min_duration,
max_duration=params.max_duration,
expected_duration=params.expected_duration,
@@ -457,9 +460,7 @@ def main():
batch_cuts = []
logging.info(f"Start processing...")
for i, cut in enumerate(raw_cuts):
if len(batch_cuts) < params.batch_size:
batch_cuts.append(cut)
else:
if len(batch_cuts) >= params.batch_size:
process_one_batch(
params,
batch_cuts=batch_cuts,
@@ -469,6 +470,7 @@ def main():
)
batch_cuts = []
logging.info(f"Number of cuts have been loaded is {i}")
batch_cuts.append(cut)
if len(batch_cuts):
process_one_batch(
params,
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fasttextsearch"
version = "0.9"
version = "0.10"
authors = [
{ name="Next-gen Kaldi development team", email="wkang.pku@gmail.com" },
]
1 change: 1 addition & 0 deletions textsearch/python/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ endfunction()
if(TS_ENABLE_TESTS)
set(test_srcs
test_find_close_matches.py
test_is_overlap.py
test_levenshtein_distance.py
test_match.py
test_row_ids_to_row_splits.py
80 changes: 80 additions & 0 deletions textsearch/python/tests/test_is_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
#
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this single test, use
#
# ctest --verbose -R match_test_py

import unittest

from textsearch.utils import is_overlap


class TestOverlap(unittest.TestCase):
    """Exercise `is_overlap` on a fixed list of candidate ranges."""

    def test_is_overlap(self):
        """Feed candidates one by one and check which segments survive.

        Mirrors the way callers use `is_overlap`:
          * (False, None): the candidate is kept.
          * (True, None): the candidate itself is dropped.
          * (True, index): the candidate is kept and the previously selected
            segment at `index` is scheduled for removal.
        """
        candidates = [
            [20, 30],
            [15, 25],
            [10, 21.1],
            [1, 10],
            [60, 70],
            [65, 73],
            [68.5, 85],
            [25, 35],
            [45, 55],
            [20, 25],
            [21, 25],
            [34.5, 46.5],
            [35, 46.1],
            [25, 35],
            [26, 34],
            [44, 70.5],
        ]
        # Both lists are owned by `is_overlap`, which inserts into / pops from
        # them; the test never modifies them directly.
        selected_ranges = []  # (start, end) pairs, kept sorted by is_overlap
        selected_indexes = []  # segment index for each entry of selected_ranges
        segments = []
        overlapped_segments = []
        for r in candidates:
            status, index = is_overlap(
                selected_ranges,
                selected_indexes,
                query=(r[0], r[1]),
                segment_index=len(segments),
                overlap_ratio=0.1,
            )
            if status and index is None:
                # The candidate overlaps too much with what is already
                # selected, so it is discarded.
                continue
            if status:
                # The candidate replaces an existing segment; remember which
                # one so it can be removed after the loop.
                overlapped_segments.append(index)
            segments.append(r)
        # Pop from the back so the remaining indexes stay valid.
        for index in sorted(overlapped_segments, reverse=True):
            segments.pop(index)
        expected_segments = [
            [10, 21.1],
            [1, 10],
            [68.5, 85],
            [25, 35],
            [21, 25],
            [35, 46.1],
        ]
        # assertEqual is not stripped under `python -O` and prints a diff on
        # failure, unlike a bare assert.
        self.assertEqual(segments, expected_segments)


if __name__ == "__main__":
unittest.main()
129 changes: 63 additions & 66 deletions textsearch/python/textsearch/match.py
Original file line number Diff line number Diff line change
@@ -145,44 +145,6 @@ def _break_query(
# [(query_start, query_end, target_start, target_end)]
segments: List[Tuple[int, int, int, int]] = []

def add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
):
num_chunk = (query_end - query_start) // segment_length
if num_chunk > 0:
for i in range(num_chunk):
real_target_end = (
target_start + segment_length
if target_start + segment_length < target_end
else target_end
)
segments.append(
(
query_start,
query_start + segment_length,
target_start,
real_target_end,
)
)
query_start += segment_length
target_start += segment_length
# if the remaining part is smaller than segment_length // 4, we will
# append it to the last segment rather than creating a new segment.
if segments and query_end - query_start < segment_length // 4:
segments[-1] = (
segments[-1][0],
query_end,
segments[-1][2],
target_end,
)
else:
segments.append((query_start, query_end, target_start, target_end))

target_doc_id = sourced_text.doc[matched_points[max_item[0]][1]]
target_base = sourced_text.doc_splits[target_doc_id]
next_target_base = sourced_text.doc_splits[target_doc_id + 1]
@@ -207,7 +169,15 @@ def add_segments(
for ind in range(max_item[0], max_item[1]):
if matched_points[ind][0] - prev_break_point[0] > segment_length:
if ind == max_item[0]:
continue
segments.append(
(
prev_break_point[0],
matched_points[ind][0],
prev_break_point[1],
matched_points[ind][1],
)
)
prev_break_point = matched_points[ind]
else:
query_start = prev_break_point[0]
query_end = matched_points[ind - 1][0]
@@ -217,17 +187,16 @@ def add_segments(
ratio = (target_end - target_start) / (query_end - query_start)
half = reference_length_difference / 2
if ratio < 1 - half or ratio > 1 + half:
logging.debug(
f"Invalid ratio for segment: "
f"{query_start, query_end, target_start, target_end}"
)
continue

prev_break_point = (query_end, target_end)
add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
segments.append(
(query_start, query_end, target_start, target_end)
)
prev_break_point = (query_end, target_end)

query_start, target_start = prev_break_point
query_end = next_query_base
@@ -248,14 +217,7 @@ def add_segments(
else:
segments.append((query_start, query_end, target_start, target_end))
else:
add_segments(
query_start,
query_end,
target_start,
target_end,
segment_length,
segments,
)
segments.append((query_start, query_end, target_start, target_end))
return segments


@@ -470,15 +432,22 @@ def align_queries(
# in sourced_text
matched_points = get_longest_increasing_pairs(seq1, seq2)

if len(matched_points) == 0:
continue

# In the algorithm of `find_close_matches`,
# `sourced_text.binary_text.size - 1` means no close_matches
trim_pos = len(matched_points) - 1
while matched_points[trim_pos][1] == sourced_text.binary_text.size - 1:
trim_pos -= 1
matched_points = matched_points[0:trim_pos]
if len(matched_points) != 0:
trim_pos = len(matched_points) - 1
while (
matched_points[trim_pos][1] == sourced_text.binary_text.size - 1
):
trim_pos -= 1
matched_points = matched_points[0:trim_pos]

if len(matched_points) == 0:
logging.warning(
f"Skipping query {q}, no matched points between query and target"
f"in close_matches."
)
continue

# The following code guarantees the matched points are in the same
# reference document. We will choose the reference document that matches
@@ -988,6 +957,7 @@ def _split_into_segments(
preceding_context_length: int = 1000,
timestamp_position: str = "middle", # previous, middle, current
silence_length_to_break: float = 0.6, # in second
overlap_ratio: float = 0.35, # percentage
min_duration: float = 2, # in second
max_duration: float = 30, # in second
expected_duration: Tuple[float, float] = (5, 20), # in second
@@ -1024,6 +994,10 @@ def _split_into_segments(
preceding or succeeding silence length greater than this value, we will
add it as a possible breaking point.
Caution: Only be used when there are no punctuations in target_source.
overlap_ratio:
The ratio of overlapping part to the query or existing segments. If the
ratio is greater than `overlap_ratio` we will drop the query or existing
segment.
min_duration:
The minimum duration (in second) allowed for a segment.
max_duration:
@@ -1079,13 +1053,28 @@ def _split_into_segments(
# Handle the overlapping
# Caution: Don't modify selected_ranges; it is manipulated in
# `is_overlap`, which always keeps it sorted.
selected_ranges: List[Tuple[int, int]] = []
# Don't modify selected_indexes either; `is_overlap` updates it
# in step with selected_ranges.
selected_ranges: List[Tuple[float, float]] = []
selected_indexes: List[int] = []
segments = []
overlapped_segments = []
for r in candidates:
if not is_overlap(
selected_ranges, query=(r[0], r[1]), overlap_ratio=0.5
):
status, index = is_overlap(
selected_ranges,
selected_indexes,
query=(aligns[r[0]]["hyp_time"], aligns[r[1]]["hyp_time"]),
segment_index=len(segments),
overlap_ratio=overlap_ratio,
)
if status:
if index is not None:
overlapped_segments.append(index)
segments.append(r)
else:
segments.append(r)
for index in sorted(overlapped_segments, reverse=True):
segments.pop(index)

results = []

@@ -1217,6 +1206,7 @@ def _split_helper(
preceding_context_length: int,
timestamp_position: str,
silence_length_to_break: float,
overlap_ratio: float,
min_duration: float,
max_duration: float,
expected_duration: Tuple[float, float],
@@ -1233,6 +1223,7 @@ def _split_helper(
preceding_context_length=preceding_context_length,
timestamp_position=timestamp_position,
silence_length_to_break=silence_length_to_break,
overlap_ratio=overlap_ratio,
min_duration=min_duration,
max_duration=max_duration,
expected_duration=expected_duration,
@@ -1250,6 +1241,7 @@ def split_aligned_queries(
preceding_context_length: int = 1000,
timestamp_position: str = "current", # previous, middle, current
silence_length_to_break: float = 0.6, # in second
overlap_ratio: float = 0.35,
min_duration: float = 2, # in second
max_duration: float = 30, # in second
expected_duration: Tuple[float, float] = (5, 20), # in second
@@ -1288,6 +1280,10 @@ def split_aligned_queries(
preceding or succeeding silence length greater than this value, we will
add it as a possible breaking point.
Caution: Only be used when there are no punctuations in target_source.
overlap_ratio:
The ratio of overlapping part to the query or existing segments. If the
ratio is greater than `overlap_ratio` we will drop the query or existing
segment.
min_duration:
The minimum duration (in second) allowed for a segment.
max_duration:
@@ -1342,6 +1338,7 @@ def split_aligned_queries(
preceding_context_length,
timestamp_position,
silence_length_to_break,
overlap_ratio,
min_duration,
max_duration,
expected_duration,
94 changes: 71 additions & 23 deletions textsearch/python/textsearch/utils.py
Original file line number Diff line number Diff line change
@@ -10,8 +10,8 @@
Pathlike = Union[str, Path]

PUCTUATIONS = {
"all": set("',.;?!():-<>/\"。;?!():-《》【】”“"),
"eos": set(".?!。?!"),
"all": set("'.;?!():-<>/\"。;?!():-《》【】”“"),
"eos": set(".?,,!。?!"),
"left": set("\"'(<《【“"),
"right": set("\"')>》】”"),
}
@@ -108,53 +108,101 @@ def row_ids_to_row_splits(row_ids: np.ndarray) -> np.ndarray:


def is_overlap(
ranges: List[Tuple[int, int]],
query: Tuple[int, int],
ranges: List[Tuple[float, float]],
indexes: List[int],
query: Tuple[float, float],
segment_index: int,
overlap_ratio: float = 0.25,
) -> bool:
) -> Tuple[bool, Union[int, None]]:
"""
Return if the given range overlaps with the existing ranges.
Return True if the given range overlaps with the existing ranges.
Caution:
`ranges` will be modified in this function (when returning False)
`ranges` and `indexes` will be modified in this function.
Note: overlapping here means the length of overlapping area is greater than
some threshold (currently, the threshold is `overlap_ratio` multiplied by the length
of the shorter overlapping ranges).
of the query or existing ranges).
Args:
ranges:
The existing ranges, it is sorted in ascending order on input, and we will
keep it sorted in this function.
indexes:
The index (into the selected segments) of the segment each range belongs to.
query:
The given range.
segment_index:
The index (into the selected segments) of query to be inserted.
overlap_ratio:
The ratio of overlapping part to the query or existing segments. If the
ratio is greater than `overlap_ratio` we will drop the query or existing
segment.
Return:
Return True if having overlap otherwise False.
Return (False, None) if no overlapping between query and existing ranges.
Return (True, None) if the ratio of overlapping part to query is greater
than `overlap_ratio`.
Return (True, dindex) if the ratio of overlapping part to existing range
is greater than `overlap_ratio`; `dindex` is the segment index (taken from
`indexes`) of the existing range that is replaced by the query.
"""
is_overlap = False
index = bisect_left(ranges, query)
if index == 0:
if ranges:
is_overlap = (
query[1] - ranges[0][0] > (query[1] - query[0]) * overlap_ratio
)
elif index == len(ranges):
if not ranges:
ranges.insert(index, query)
indexes.insert(index, segment_index)
return False, None

# overlapping on query
if index > 0:
is_overlap = (
ranges[index - 1][1] - query[0]
> (query[1] - query[0]) * overlap_ratio
)
else:
if is_overlap:
return True, None

if index < len(ranges):
is_overlap = (
ranges[index - 1][1] - query[0]
> (query[1] - query[0]) * overlap_ratio
) or (
query[1] - ranges[index][0] > (query[1] - query[0]) * overlap_ratio
)
if is_overlap:
return True, None

if not is_overlap:
# overlapping on existing ranges
is_overlap_left = False
if index > 0:
is_overlap_left = (
ranges[index - 1][1] - query[0]
> (ranges[index - 1][1] - ranges[index - 1][0]) * overlap_ratio
)

is_overlap_right = False
if index < len(ranges):
is_overlap_right = (
query[1] - ranges[index][0]
> (ranges[index][1] - ranges[index][0]) * overlap_ratio
)

if is_overlap_left or is_overlap_right:
if is_overlap_left and not is_overlap_right:
ranges.insert(index, query)
ranges.pop(index - 1)
indexes.insert(index, segment_index)
dindex = indexes.pop(index - 1)
return True, dindex
elif is_overlap_right and not is_overlap_left:
ranges.insert(index, query)
ranges.pop(index + 1)
indexes.insert(index, segment_index)
dindex = indexes.pop(index + 1)
return True, dindex
else:
return True, None
else:
ranges.insert(index, query)
return is_overlap
indexes.insert(index, segment_index)
return False, None


def is_punctuation(c: str, eos_only: bool = False) -> bool:
@@ -165,7 +213,7 @@ def is_punctuation(c: str, eos_only: bool = False) -> bool:
c:
The given character.
eos_only:
If True the punctuations are only those indicating end of a sentence (.?! for now).
If True the punctuations are only those indicating end of a sentence (,.?! for now).
"""
if eos_only:
return c in PUCTUATIONS["eos"]

0 comments on commit 6847d91

Please sign in to comment.