From 6847d91c580eb4291e90767142a90c45f1c2e523 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 10 Jan 2024 19:47:34 +0800 Subject: [PATCH] Fix break segments and overlap (#58) * Fix the issue of dropping cuts * fix overlap, still has some problems * refactor is_overlap * Fix overlap * release v0.10 --- CMakeLists.txt | 2 +- examples/libriheavy/matching.py | 10 +- pyproject.toml | 2 +- textsearch/python/tests/CMakeLists.txt | 1 + textsearch/python/tests/test_is_overlap.py | 80 +++++++++++++ textsearch/python/textsearch/match.py | 129 ++++++++++----------- textsearch/python/textsearch/utils.py | 94 +++++++++++---- 7 files changed, 223 insertions(+), 95 deletions(-) create mode 100755 textsearch/python/tests/test_is_overlap.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f05726..ba5c122 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.12 FATAL_ERROR) project(textsearch) -set(TS_VERSION "0.9") +set(TS_VERSION "0.10") set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/examples/libriheavy/matching.py b/examples/libriheavy/matching.py index 4019127..f128c02 100755 --- a/examples/libriheavy/matching.py +++ b/examples/libriheavy/matching.py @@ -88,13 +88,14 @@ def get_params() -> AttributeDict: # you can find the docs in textsearch/match.py#align_queries "num_close_matches": 2, "segment_length": 5000, - "reference_length_difference": 0.1, + "reference_length_difference": 0.4, "min_matched_query_ratio": 0.33, # parameters for splitting aligned queries # you can find the docs in textsearch/match.py#split_aligned_queries "preceding_context_length": 1000, "timestamp_position": "current", "silence_length_to_break": 0.45, + "overlap_ratio": 0.4, "min_duration": 2, "max_duration": 30, "expected_duration": (5, 20), @@ -188,6 +189,7 @@ def load_data( books.append(book) if not transcripts: + logging.warning(f"No transcripts found.") return {} logging.debug(f"Worker[{worker_index}] loading cuts and books done.") @@ -321,6 +323,7 @@ def split( preceding_context_length=params.preceding_context_length, timestamp_position=params.timestamp_position, silence_length_to_break=params.silence_length_to_break, + overlap_ratio=params.overlap_ratio, min_duration=params.min_duration, max_duration=params.max_duration, expected_duration=params.expected_duration, @@ -457,9 +460,7 @@ def main(): batch_cuts = [] logging.info(f"Start processing...") for i, cut in enumerate(raw_cuts): - if len(batch_cuts) < params.batch_size: - batch_cuts.append(cut) - else: + if len(batch_cuts) >= params.batch_size: process_one_batch( params, batch_cuts=batch_cuts, @@ -469,6 +470,7 @@ def main(): ) batch_cuts = [] logging.info(f"Number of cuts have been loaded is {i}") + batch_cuts.append(cut) if len(batch_cuts): process_one_batch( params, diff --git a/pyproject.toml b/pyproject.toml index e0276fd..995faee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" [project] name = "fasttextsearch" -version = "0.9" +version = "0.10" authors = [ { name="Next-gen Kaldi development team", email="wkang.pku@gmail.com" }, ] diff --git a/textsearch/python/tests/CMakeLists.txt b/textsearch/python/tests/CMakeLists.txt index 2712e2b..2cd1b54 100644 --- a/textsearch/python/tests/CMakeLists.txt +++ b/textsearch/python/tests/CMakeLists.txt @@ -19,6 +19,7 @@ endfunction() if(TS_ENABLE_TESTS) set(test_srcs test_find_close_matches.py + test_is_overlap.py test_levenshtein_distance.py test_match.py test_row_ids_to_row_splits.py diff --git a/textsearch/python/tests/test_is_overlap.py b/textsearch/python/tests/test_is_overlap.py new file mode 100755 index 0000000..f920942 --- /dev/null +++ b/textsearch/python/tests/test_is_overlap.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this single test, use +# +# ctest --verbose -R match_test_py + +import unittest + +from textsearch.utils import is_overlap + + +class TestOverlap(unittest.TestCase): + def test_is_overlap(self): + candidates = [ + [20, 30], + [15, 25], + [10, 21.1], + [1, 10], + [60, 70], + [65, 73], + [68.5, 85], + [25, 35], + [45, 55], + [20, 25], + [21, 25], + [34.5, 46.5], + [35, 46.1], + [25, 35], + [26, 34], + [44, 70.5], + ] + selected_ranges: List[Tuple[float, float]] = [] + selected_indexes: List[int] = [] + segments = [] + overlapped_segments = [] + for r in candidates: + status, index = is_overlap( + selected_ranges, + selected_indexes, + query=(r[0], r[1]), + segment_index=len(segments), + overlap_ratio=0.1, + ) + if status: + if index is not None: + overlapped_segments.append(index) + segments.append(r) + else: + segments.append(r) + for index in sorted(overlapped_segments, reverse=True): + segments.pop(index) + expected_segments = [ + [10, 21.1], + [1, 10], + [68.5, 85], + [25, 35], + [21, 25], + [35, 46.1], + ] + assert segments == expected_segments + + +if __name__ == "__main__": + unittest.main() diff --git a/textsearch/python/textsearch/match.py b/textsearch/python/textsearch/match.py index 1f7be1d..f82477f 100644 --- a/textsearch/python/textsearch/match.py +++ b/textsearch/python/textsearch/match.py @@ -145,44 +145,6 @@ def _break_query( # [(query_start, query_end, target_start, target_end)] segments: List[Tuple[int, int, int, int]] = [] - def add_segments( - query_start, - query_end, - target_start, - target_end, - segment_length, - segments, - ): - num_chunk = (query_end - query_start) // segment_length - if num_chunk > 0: - for i in range(num_chunk): - real_target_end = ( - target_start + segment_length - if target_start + segment_length < target_end - else target_end - ) - segments.append( - ( - query_start, - query_start + segment_length, - target_start, - real_target_end, - ) - ) - query_start += segment_length - target_start += segment_length - # if the remaining part is smaller than segment_length // 4, we will - # append it to the last segment rather than creating a new segment. - if segments and query_end - query_start < segment_length // 4: - segments[-1] = ( - segments[-1][0], - query_end, - segments[-1][2], - target_end, - ) - else: - segments.append((query_start, query_end, target_start, target_end)) - target_doc_id = sourced_text.doc[matched_points[max_item[0]][1]] target_base = sourced_text.doc_splits[target_doc_id] next_target_base = sourced_text.doc_splits[target_doc_id + 1] @@ -207,7 +169,15 @@ def add_segments( for ind in range(max_item[0], max_item[1]): if matched_points[ind][0] - prev_break_point[0] > segment_length: if ind == max_item[0]: - continue + segments.append( + ( + prev_break_point[0], + matched_points[ind][0], + prev_break_point[1], + matched_points[ind][1], + ) + ) + prev_break_point = matched_points[ind] else: query_start = prev_break_point[0] query_end = matched_points[ind - 1][0] @@ -217,17 +187,16 @@ def add_segments( ratio = (target_end - target_start) / (query_end - query_start) half = reference_length_difference / 2 if ratio < 1 - half or ratio > 1 + half: + logging.debug( + f"Invalid ratio for segment: " + f"{query_start, query_end, target_start, target_end}" + ) continue - prev_break_point = (query_end, target_end) - add_segments( - query_start, - query_end, - target_start, - target_end, - segment_length, - segments, + segments.append( + (query_start, query_end, target_start, target_end) ) + prev_break_point = (query_end, target_end) query_start, target_start = prev_break_point query_end = next_query_base @@ -248,14 +217,7 @@ def add_segments( else: segments.append((query_start, query_end, target_start, target_end)) else: - add_segments( - query_start, - query_end, - target_start, - target_end, - segment_length, - segments, - ) + segments.append((query_start, query_end, target_start, target_end)) return segments @@ -470,15 +432,22 @@ def align_queries( # in sourced_text matched_points = get_longest_increasing_pairs(seq1, seq2) - if len(matched_points) == 0: - continue - # In the algorithm of `find_close_matches`, # `sourced_text.binary_text.size - 1` means no close_matches - trim_pos = len(matched_points) - 1 - while matched_points[trim_pos][1] == sourced_text.binary_text.size - 1: - trim_pos -= 1 - matched_points = matched_points[0:trim_pos] + if len(matched_points) != 0: + trim_pos = len(matched_points) - 1 + while ( + matched_points[trim_pos][1] == sourced_text.binary_text.size - 1 + ): + trim_pos -= 1 + matched_points = matched_points[0:trim_pos] + + if len(matched_points) == 0: + logging.warning( + f"Skipping query {q}, no matched points between query and target" + f"in close_matches." + ) + continue # The following code guarantees the matched points are in the same # reference document. We will choose the reference document that matches @@ -988,6 +957,7 @@ def _split_into_segments( preceding_context_length: int = 1000, timestamp_position: str = "middle", # previous, middle, current silence_length_to_break: float = 0.6, # in second + overlap_ratio: float = 0.35, # percentage min_duration: float = 2, # in second max_duration: float = 30, # in second expected_duration: Tuple[float, float] = (5, 20), # in second @@ -1024,6 +994,10 @@ def _split_into_segments( preceding or succeeding silence length greater than this value, we will add it as a possible breaking point. Caution: Only be used when there are no punctuations in target_source. + overlap_ratio: + The ratio of overlapping part to the query or existing segments. If the + ratio is greater than `overlap_ratio` we will drop the query or existing + segment. min_duration: The minimum duration (in second) allowed for a segment. max_duration: @@ -1079,13 +1053,28 @@ def _split_into_segments( # Handle the overlapping # Caution: Don't modified selected_ranges, it will be manipulated in # `is_overlap` and will be always kept sorted. - selected_ranges: List[Tuple[int, int]] = [] + # Don't modified selected_indexes also, it will be manipulated in `is_overlap` + # according to selected_ranges. + selected_ranges: List[Tuple[float, float]] = [] + selected_indexes: List[int] = [] segments = [] + overlapped_segments = [] for r in candidates: - if not is_overlap( - selected_ranges, query=(r[0], r[1]), overlap_ratio=0.5 - ): + status, index = is_overlap( + selected_ranges, + selected_indexes, + query=(aligns[r[0]]["hyp_time"], aligns[r[1]]["hyp_time"]), + segment_index=len(segments), + overlap_ratio=overlap_ratio, + ) + if status: + if index is not None: + overlapped_segments.append(index) + segments.append(r) + else: segments.append(r) + for index in sorted(overlapped_segments, reverse=True): + segments.pop(index) results = [] @@ -1217,6 +1206,7 @@ def _split_helper( preceding_context_length: int, timestamp_position: str, silence_length_to_break: float, + overlap_ratio: float, min_duration: float, max_duration: float, expected_duration: Tuple[float, float], @@ -1233,6 +1223,7 @@ def _split_helper( preceding_context_length=preceding_context_length, timestamp_position=timestamp_position, silence_length_to_break=silence_length_to_break, + overlap_ratio=overlap_ratio, min_duration=min_duration, max_duration=max_duration, expected_duration=expected_duration, @@ -1250,6 +1241,7 @@ def split_aligned_queries( preceding_context_length: int = 1000, timestamp_position: str = "current", # previous, middle, current silence_length_to_break: float = 0.6, # in second + overlap_ratio: float = 0.35, min_duration: float = 2, # in second max_duration: float = 30, # in second expected_duration: Tuple[float, float] = (5, 20), # in second @@ -1288,6 +1280,10 @@ def split_aligned_queries( preceding or succeeding silence length greater than this value, we will add it as a possible breaking point. Caution: Only be used when there are no punctuations in target_source. + overlap_ratio: + The ratio of overlapping part to the query or existing segments. If the + ratio is greater than `overlap_ratio` we will drop the query or existing + segment. min_duration: The minimum duration (in second) allowed for a segment. max_duration: @@ -1342,6 +1338,7 @@ def split_aligned_queries( preceding_context_length, timestamp_position, silence_length_to_break, + overlap_ratio, min_duration, max_duration, expected_duration, diff --git a/textsearch/python/textsearch/utils.py b/textsearch/python/textsearch/utils.py index d78aa87..72ecfb5 100644 --- a/textsearch/python/textsearch/utils.py +++ b/textsearch/python/textsearch/utils.py @@ -10,8 +10,8 @@ Pathlike = Union[str, Path] PUCTUATIONS = { - "all": set("',.;?!():-<>/\",。;?!():-《》【】”“"), - "eos": set(".?!。?!"), + "all": set("'.;?!():-<>/\"。;?!():-《》【】”“"), + "eos": set(".?,,!。?!"), "left": set("\"'(<《【“"), "right": set("\"')>》】”"), } @@ -108,53 +108,101 @@ def row_ids_to_row_splits(row_ids: np.ndarray) -> np.ndarray: def is_overlap( - ranges: List[Tuple[int, int]], - query: Tuple[int, int], + ranges: List[Tuple[float, float]], + indexes: List[int], + query: Tuple[float, float], + segment_index: int, overlap_ratio: float = 0.25, -) -> bool: +) -> Tuple[bool, Union[int, None]]: """ - Return if the given range overlaps with the existing ranges. + Return True if the given range overlaps with the existing ranges. Caution: - `ranges` will be modified in this function (when returning False) + `ranges` and `indexes` will be modified in this function. Note: overlapping here means the length of overlapping area is greater than some threshold (currently, the threshold is `overlap_ratio` multiply the length - of the shorter overlapping ranges). + of the query or existing ranges). Args: ranges: The existing ranges, it is sorted in ascending order on input, and we will keep it sorted in this function. + indexes: + The index (into the selected segments) of each range belongs to. query: The given range. + segment_index: + The index (into the selected segments) of query to be inserted. + overlap_ratio: + The ratio of overlapping part to the query or existing segments. If the + ratio is greater than `overlap_ratio` we will drop the query or existing + segment. Return: - Return True if having overlap otherwise False. + Return (False, None) if no overlapping between query and existing ranges. + Return (True, None) if the ratio of overlapping part to query is greater + than `overlap_ratio`. + Return (True, dindex) if the ratio of overlapping part to existing range + is greater than `overlap_ratio`, `dindex` is the index (can get from indexes) + of the existing range. """ - is_overlap = False index = bisect_left(ranges, query) - if index == 0: - if ranges: - is_overlap = ( - query[1] - ranges[0][0] > (query[1] - query[0]) * overlap_ratio - ) - elif index == len(ranges): + if not ranges: + ranges.insert(index, query) + indexes.insert(index, segment_index) + return False, None + + # overlapping on query + if index > 0: is_overlap = ( ranges[index - 1][1] - query[0] > (query[1] - query[0]) * overlap_ratio ) - else: + if is_overlap: + return True, None + + if index < len(ranges): is_overlap = ( - ranges[index - 1][1] - query[0] - > (query[1] - query[0]) * overlap_ratio - ) or ( query[1] - ranges[index][0] > (query[1] - query[0]) * overlap_ratio ) + if is_overlap: + return True, None - if not is_overlap: + # overlapping on existing ranges + is_overlap_left = False + if index > 0: + is_overlap_left = ( + ranges[index - 1][1] - query[0] + > (ranges[index - 1][1] - ranges[index - 1][0]) * overlap_ratio + ) + + is_overlap_right = False + if index < len(ranges): + is_overlap_right = ( + query[1] - ranges[index][0] + > (ranges[index][1] - ranges[index][0]) * overlap_ratio + ) + + if is_overlap_left or is_overlap_right: + if is_overlap_left and not is_overlap_right: + ranges.insert(index, query) + ranges.pop(index - 1) + indexes.insert(index, segment_index) + dindex = indexes.pop(index - 1) + return True, dindex + elif is_overlap_right and not is_overlap_left: + ranges.insert(index, query) + ranges.pop(index + 1) + indexes.insert(index, segment_index) + dindex = indexes.pop(index + 1) + return True, dindex + else: + return True, None + else: ranges.insert(index, query) - return is_overlap + indexes.insert(index, segment_index) + return False, None def is_punctuation(c: str, eos_only: bool = False) -> bool: @@ -165,7 +213,7 @@ def is_punctuation(c: str, eos_only: bool = False) -> bool: c: The given character. eos_only: - If True the punctuations are only those indicating end of a sentence (.?! for now). + If True the punctuations are only those indicating end of a sentence (,.?! for now). """ if eos_only: return c in PUCTUATIONS["eos"]