diff --git a/.gitignore b/.gitignore index a7df95a7..d6f29145 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,4 @@ node_modules .DS_Store ._.DS_Store docs/tutorials/mouse_biccn.ipynb -docs/tutorials/.ipynb_checkpoints \ No newline at end of file +.ipynb_checkpoints/ diff --git a/src/crested/__init__.py b/src/crested/__init__.py index 1302b1ed..169e1035 100644 --- a/src/crested/__init__.py +++ b/src/crested/__init__.py @@ -25,7 +25,7 @@ from . import pl, pp, tl, utils from ._datasets import get_dataset, get_model, get_motif_db from ._genome import Genome, register_genome -from ._io import import_beds, import_bigwigs +from ._io import import_beds, import_bigwigs, import_bigwigs_raw, read_lazy_h5ad __all__ = [ "pl", diff --git a/src/crested/_io.py b/src/crested/_io.py index 9e741ebc..713bac1f 100644 --- a/src/crested/_io.py +++ b/src/crested/_io.py @@ -16,7 +16,7 @@ from anndata import AnnData from loguru import logger from scipy.sparse import csr_matrix - +from tqdm import tqdm from crested import _conf as conf @@ -61,9 +61,9 @@ def _custom_region_sort(region: str) -> tuple[int, int, int]: def _read_chromsizes(chromsizes_file: PathLike) -> dict[str, int]: """Read chromsizes file into a dictionary.""" chromsizes = pd.read_csv( - chromsizes_file, sep="\t", header=None, names=["chr", "size"] + chromsizes_file, sep="\t", header=None, names=["chrom", "size"] ) - chromsizes_dict = chromsizes.set_index("chr")["size"].to_dict() + chromsizes_dict = chromsizes.set_index("chrom")["size"].to_dict() return chromsizes_dict @@ -90,12 +90,12 @@ def _extract_values_from_bigwig( if chrom in chromosomes_in_bigwig: temp_bed_file.file.write(line.encode("utf-8")) bed_entries_to_keep_idx.append(idx) - # Make sure all content is written to temporary BED file. temp_bed_file.file.flush() total_bed_entries = idx + 1 bed_entries_to_keep_idx = np.array(bed_entries_to_keep_idx, np.intp) + # Branch logic for target if target == "mean": with pybigtools.open(bw_file, "r") as bw: values = np.fromiter( @@ -124,20 +124,31 @@ def _extract_values_from_bigwig( dtype=np.float32, ) ) + else: raise ValueError(f"Unsupported target '{target}'") - # Remove temporary BED file. temp_bed_file.close() - if values.shape[0] != total_bed_entries: - # Set all values for BED entries for which the chromosome was not in in the bigWig file to NaN. - all_values = np.full(total_bed_entries, np.nan, dtype=np.float32) - all_values[bed_entries_to_keep_idx] = values - return all_values - else: - return values + # Now handle missing chromosome lines + if target == "raw": + # 'values' is 2D [n_valid_regions, region_length or max_length]. + # We have total_bed_entries lines in original bed, but only n_valid in 'values'. + # If we want to keep shape [n_regions, region_length] for the final result, + # we must build a bigger 2D array with shape [total_bed_entries, max_length]. + all_data = np.full((total_bed_entries, values.shape[1]), np.nan, dtype=np.float32) + all_data[bed_entries_to_keep_idx, :] = values + return all_data + else: + # 'values' is 1D + if values.shape[0] != total_bed_entries: + # Set all values for BED entries for which the chromosome was not in bigWig to NaN. 
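+            # 'values' only covers regions whose chromosome was present in the bigWig,
+            # so scatter it back into a full-length array via bed_entries_to_keep_idx.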
+ all_values = np.full(total_bed_entries, np.nan, dtype=np.float32) + all_values[bed_entries_to_keep_idx] = values + return all_values + else: + return values def _extract_tracks_from_bigwig( bw_file: PathLike, @@ -221,14 +232,10 @@ def _extract_tracks_from_bigwig( def _read_consensus_regions( - regions_file: PathLike, chromsizes_file: PathLike | None = None + regions_file: PathLike, chromsizes_dict: dict | None = None ) -> pd.DataFrame: """Read consensus regions BED file and filter out regions not within chromosomes.""" - if chromsizes_file is not None: - chromsizes_file = Path(chromsizes_file) - if not chromsizes_file.is_file(): - raise FileNotFoundError(f"File '{chromsizes_file}' not found") - if chromsizes_file is None and not conf.genome: + if chromsizes_dict is None and not conf.genome: logger.warning( "Chromsizes file not provided. Will not check if regions are within chromosomes", stacklevel=1, @@ -240,24 +247,25 @@ def _read_consensus_regions( usecols=[0, 1, 2], dtype={0: str, 1: "Int32", 2: "Int32"}, ) + consensus_peaks.columns = ["chrom","start","end"] consensus_peaks["region"] = ( - consensus_peaks[0].astype(str) + consensus_peaks["chrom"].astype(str) + ":" - + consensus_peaks[1].astype(str) + + consensus_peaks["start"].astype(str) + "-" - + consensus_peaks[2].astype(str) + + consensus_peaks["end"].astype(str) ) - if chromsizes_file: - chromsizes_dict = _read_chromsizes(chromsizes_file) + if chromsizes_dict: + pass elif conf.genome: chromsizes_dict = conf.genome.chrom_sizes else: return consensus_peaks valid_mask = consensus_peaks.apply( - lambda row: row[0] in chromsizes_dict - and row[1] >= 0 - and row[2] <= chromsizes_dict[row[0]], + lambda row: row["chrom"] in chromsizes_dict + and row["start"] >= 0 + and row["end"] <= chromsizes_dict[row[0]], axis=1, ) consensus_peaks_filtered = consensus_peaks[valid_mask] @@ -270,18 +278,18 @@ def _read_consensus_regions( def _create_temp_bed_file( - consensus_peaks: pd.DataFrame, target_region_width: int | None + consensus_peaks: pd.DataFrame, target_region_width: int, adjust = True ) -> str: """Adjust consensus regions to a target width and create a temporary BED file.""" adjusted_peaks = consensus_peaks.copy() - if target_region_width: + if adjust: adjusted_peaks[1] = adjusted_peaks.apply( lambda row: max(0, row[1] - (target_region_width - (row[2] - row[1])) // 2), axis=1, ) adjusted_peaks[2] = adjusted_peaks[1] + target_region_width - adjusted_peaks[1] = adjusted_peaks[1].astype(int) - adjusted_peaks[2] = adjusted_peaks[2].astype(int) + adjusted_peaks[1] = adjusted_peaks[1].astype(int) + adjusted_peaks[2] = adjusted_peaks[2].astype(int) # Create a temporary BED file temp_bed_file = "temp_adjusted_regions.bed" @@ -310,7 +318,7 @@ def _check_bed_file_format(bed_file: PathLike) -> None: def import_beds( beds_folder: PathLike, regions_file: PathLike | None = None, - chromsizes_file: PathLike | None = None, + chromsizes_dict: dict | None = None, classes_subset: list | None = None, remove_empty_regions: bool = True, compress: bool = False, @@ -340,8 +348,8 @@ def import_beds( List of classes to include in the AnnData object. If None, all files will be included. Classes should be named after the file name without the extension. - chromsizes_file - File path of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries. + chromsizes_dict + dict of chromsizes. Used for checking if the new regions are within the chromosome boundaries. If not provided, will look for a registered genome object. 
remove_empty_regions Remove regions that are not open in any class (only possible if regions_file is provided) @@ -361,7 +369,6 @@ def import_beds( >>> anndata = crested.import_beds( ... beds_folder="path/to/beds/folder/", ... regions_file="path/to/regions.bed", - ... chromsizes_file="path/to/chrom.sizes", ... classes_subset=["Topic_1", "Topic_2"], ... ) """ @@ -381,7 +388,7 @@ def import_beds( if regions_file: # Read consensus regions BED file and filter out regions not within chromosomes _check_bed_file_format(regions_file) - consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file) + consensus_peaks = _read_consensus_regions(regions_file, chromsizes_dict) binary_matrix = pd.DataFrame(0, index=[], columns=consensus_peaks["region"]) file_paths = [] @@ -462,7 +469,7 @@ def import_beds( ann_data.obs["file_path"] = file_paths ann_data.obs["n_open_regions"] = ann_data.X.sum(axis=1) ann_data.var["n_classes"] = ann_data.X.sum(axis=0) - ann_data.var["chr"] = ann_data.var.index.str.split(":").str[0] + ann_data.var["chrom"] = ann_data.var.index.str.split(":").str[0] ann_data.var["start"] = ( ann_data.var.index.str.split(":").str[1].str.split("-").str[0] ).astype(int) @@ -493,7 +500,7 @@ def import_beds( def import_bigwigs( bigwigs_folder: PathLike, regions_file: PathLike, - chromsizes_file: PathLike | None = None, + chromsizes_dict: dict | None = None, target: str = "mean", target_region_width: int | None = None, compress: bool = False, @@ -514,11 +521,11 @@ def import_bigwigs( Folder name containing the bigWig files. regions_file File name of the consensus regions BED file. - chromsizes_file - File name of the chromsizes file. Used for checking if the new regions are within the chromosome boundaries. + chromsizes_dict + Dictionary of chrom sizes. Used for checking if the new regions are within the chromosome boundaries. If not provided, will look for a registered genome object. target - Target value to extract from bigwigs. Can be 'mean', 'max', 'count', or 'logcount' + Target value to extract from bigwigs. Can be 'raw', 'mean', 'max', 'count', or 'logcount' target_region_width Width of region that the bigWig target value will be extracted from. If None, the consensus region width will be used. 
@@ -551,7 +558,7 @@ def import_bigwigs(
 
     # Read consensus regions BED file and filter out regions not within chromosomes
     _check_bed_file_format(regions_file)
-    consensus_peaks = _read_consensus_regions(regions_file, chromsizes_file)
+    consensus_peaks = _read_consensus_regions(regions_file, chromsizes_dict)
 
     bed_file = _create_temp_bed_file(consensus_peaks, target_region_width)
 
@@ -575,18 +582,22 @@
     # Process bigWig files in parallel and collect the results
     logger.info(f"Extracting values from {len(bw_files)} bigWig files...")
     all_results = []
-    with ProcessPoolExecutor() as executor:
-        futures = [
-            executor.submit(
-                _extract_values_from_bigwig,
-                bw_file,
-                bed_file,
-                target,
-            )
-            for bw_file in bw_files
-        ]
-        for future in futures:
-            all_results.append(future.result())
+    # with ProcessPoolExecutor() as executor:
+    #     futures = [
+    #         executor.submit(
+    #             _extract_values_from_bigwig,
+    #             bw_file,
+    #             bed_file,
+    #             target,
+    #         )
+    #         for bw_file in bw_files
+    #     ]
+    #     for future in futures:
+    #         all_results.append(future.result())
+
+    for bw_file in bw_files:
+        result = _extract_values_from_bigwig(bw_file, bed_file, target=target)
+        all_results.append(result)
 
     os.remove(bed_file)
 
@@ -603,7 +614,7 @@
     var_df = pd.DataFrame(
         {
             "region": consensus_peaks["region"],
-            "chr": consensus_peaks["region"].str.split(":").str[0],
+            "chrom": consensus_peaks["region"].str.split(":").str[0],
             "start": (
                 consensus_peaks["region"].str.split(":").str[1].str.split("-").str[0]
             ).astype(int),
@@ -627,3 +638,664 @@
     )
 
     return adata
+
+
+import scipy
+from scipy.sparse import csr_matrix
+import anndata
+from anndata import AnnData
+from anndata._io.h5ad import read_elem, read_dataframe
+from pathlib import Path
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import Any, Literal
+import numpy as np
+import h5py
+import pybigtools
+from os import PathLike
+from crested._genome import Genome
+
+
+class H5Source:
+    """
+    A lightweight reference to a dataset in an HDF5 file on disk.
+    It does NOT hold data in memory, only the filename/path.
+    """
+
+    def __init__(self, filename: str, dataset_path: str):
+        self.filename = filename
+        self.dataset_path = dataset_path
+
+    @property
+    def shape(self):
+        with h5py.File(self.filename, "r") as f:
+            return f[self.dataset_path].shape
+
+    def __getitem__(self, idx):
+        with h5py.File(self.filename, "r") as f:
+            return f[self.dataset_path][idx]
+
+
+class LazyTensor(np.ndarray):
+    """
+    A 2D "indexable" view of a 3D+ dataset [rows, cols, coverage_length, ...].
+    For each (row, col), we retrieve a 1D coverage array of shape [coverage_length].
+    """
+
+    def __new__(
+        cls,
+        source,
+        row_labels=None,
+        col_labels=None,
+    ):
+        """
+        source: an H5Source or similar, expected shape = (n_rows, n_cols, coverage_length).
+        row_labels: optional list of row (track) labels
+        col_labels: optional list of column (region) labels
+        """
+        obj = super().__new__(cls, shape=(), dtype=float)
+        obj.source = source  # e.g.
H5Source + # We define the 2D "index shape" as (source.shape[0], source.shape[1]) + # coverage is the 3rd dimension + obj._lazy_shape = (obj.source.shape[0], obj.source.shape[1]) + + # Create label -> index maps if provided + if row_labels: + obj.row_labels = {label: i for i, label in enumerate(row_labels)} + else: + obj.row_labels = None + + if col_labels: + obj.col_labels = {label: i for i, label in enumerate(col_labels)} + else: + obj.col_labels = None + + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.source = getattr(obj, "source", None) + self._lazy_shape = getattr(obj, "_lazy_shape", None) + self.row_labels = getattr(obj, "row_labels", None) + self.col_labels = getattr(obj, "col_labels", None) + + def _get_indices(self, labels, label_map): + """ + Convert label-based indices into integer indices. + Returns array of float so we can store np.nan for missing labels. + """ + if label_map is None: + raise IndexError("Label-based indexing is not supported on this axis.") + indices = [] + for label in labels: + if label in label_map: + indices.append(label_map[label]) + else: + indices.append(np.nan) # missing + return np.array(indices, dtype=float) + + def __getitem__(self, index): + """ + Expects 2D indexing, e.g. lazy_tensor[row_idx, col_idx], + returning a coverage array of shape (len(row_idx), len(col_idx), coverage_length). + """ + if not isinstance(index, tuple) or len(index) != 2: + raise IndexError("LazyTensor expects 2D indexing, e.g. lazy_tensor[i, j].") + + row_index, col_index = index + + # 1) parse row_index + if isinstance(row_index, str): + row_index = self._get_indices([row_index], self.row_labels) + elif isinstance(row_index, (list, np.ndarray)) and len(row_index) > 0 and isinstance(row_index[0], str): + row_index = self._get_indices(row_index, self.row_labels) + elif isinstance(row_index, slice): + row_index = np.arange(*row_index.indices(self._lazy_shape[0]), dtype=float) + elif isinstance(row_index, (int, np.integer)): + row_index = np.array([row_index], dtype=float) + else: + row_index = np.array(row_index, dtype=float) + + # 2) parse col_index + if isinstance(col_index, str): + col_index = self._get_indices([col_index], self.col_labels) + elif isinstance(col_index, (list, np.ndarray)) and len(col_index) > 0 and isinstance(col_index[0], str): + col_index = self._get_indices(col_index, self.col_labels) + elif isinstance(col_index, slice): + col_index = np.arange(*col_index.indices(self._lazy_shape[1]), dtype=float) + elif isinstance(col_index, (int, np.integer)): + col_index = np.array([col_index], dtype=float) + else: + col_index = np.array(col_index, dtype=float) + + # 3) Build an output array => (n_rows, n_cols, coverage_length) + coverage_length = self.source.shape[2] # the 3rd dimension + out_shape = (len(row_index), len(col_index), coverage_length) + + result = np.empty(out_shape, dtype=np.float32) + result.fill(np.nan) + + # 4) read from source for each (r, c) + for i, r in enumerate(row_index): + for j, c in enumerate(col_index): + if not np.isnan(r) and not np.isnan(c): + r_int = int(r) + c_int = int(c) + # read coverage from [r_int, c_int, :] + data_1d = self.source[r_int, c_int] + # data_1d is shape = (coverage_length,) + result[i, j, :] = data_1d + + return result + + @property + def shape(self): + # The 2D "index shape" is (n_tracks, n_regions) + return self._lazy_shape + + def __repr__(self): + return f"" + +class LazyMatrix(np.ndarray): + """ + A 2D NumPy array subclass. 
Indexing behaves like a normal matrix: + - [i, j] => a scalar + - [i, :] => a 1D array + - [i:j, p:q] => a 2D array + The actual data is fetched on demand from `source`. + """ + + def __new__(cls, source, dtype=float): + obj = super().__new__(cls, shape=(), dtype=dtype) + obj.source = source + obj._lazy_shape = source.shape # must be 2D + return obj + + def __array_finalize__(self, obj): + if obj is None: + return + self.source = getattr(obj, 'source', None) + self._lazy_shape = getattr(obj, '_lazy_shape', None) + + @property + def shape(self): + return self._lazy_shape # e.g. (n_vars, n_vars) + + def __getitem__(self, idx): + """ + Standard 2D indexing. + The result is a scalar, 1D, or 2D np.ndarray in memory. + """ + data = self.source[idx] # Read from disk + data = np.asarray(data, dtype=self.dtype) + return data + + def __repr__(self): + return f"" + + +class LazyAnnData(anndata.AnnData): + """ + Subclass that can: + 1) Link `adata.X` to a 3D LazyTensor referencing `uns[track_key]`. + 2) Turn certain varp keys into 2D LazyMatrix references on disk. + """ + + def __init__( + self, + *args, + track_key='tracks', # The key in uns that will feed a 3D LazyTensor for X + lazy_varp_keys=None, # A list of keys in varp to store as 2D LazyMatrix + **kwargs + ): + super().__init__(*args, **kwargs) + self._track_key = track_key + self._lazy_varp_keys = set(lazy_varp_keys or []) + + # Post-process. If track_key is set, link X => uns[track_key] + self._make_tracks_lazy() + + # If there are varp keys to be lazy, replace them + self._make_varp_lazy() + + def _make_tracks_lazy(self): + """ + If `track_key` is provided and exists in uns, we replace `self.X` + with a 3D LazyTensor that references the on-disk dataset at uns[track_key]. + """ + if self.filename is None: + return + if self._track_key is None: + return + + import h5py + with h5py.File(self.filename, 'r') as f: + if "uns" not in f: + return + if self._track_key in f["uns"]: + dataset_path = f"uns/{self._track_key}" + h5src = H5Source(self.filename, dataset_path) + # e.g. shape=(N, N, M) or any logic you want: + # Suppose the dataset is shape (N, M). We'll interpret that as (N, N, M): + # shape = (self.shape[0], self.shape[1], h5src.shape[1]) + lazy_tens = LazyTensor(source=h5src, row_labels = list(self.obs_names), col_labels = list(self.var_names)) #shape=shape + self.layers[self._track_key] = lazy_tens + + def _make_varp_lazy(self): + """ + For each key in _lazy_varp_keys, replace varp[key] with a 2D LazyMatrix. + """ + if self.filename is None: + return + import h5py + with h5py.File(self.filename, 'r') as f: + if "varp" not in f: + return + for key in self._lazy_varp_keys: + if key in f["varp"]: + dataset_path = f"varp/{key}" + h5src = H5Source(self.filename, dataset_path) + lazy_mat = LazyMatrix(h5src) + self._varp[key] = lazy_mat + + def write(self, **kwargs): + """ + In-place partial update, ensuring the file is in 'r+' mode. + Example: rewriting `obs`, `var`, and `obsm`, + but skipping huge `uns` or `varp`. + """ + if self.filename is None: + raise ValueError("No filename set. 
Can't do partial writes.") + + import h5py + from anndata._io.h5ad import write_elem + + with h5py.File(self.filename, 'r+') as f: + # Example: update `obs` if it's stored as a group + if 'obs' in f and isinstance(f['obs'], h5py.Group): + del f['obs'] # remove old + write_elem(f,'obs',self.obs) + + # Example: update `var` + if 'var' in f and isinstance(f['var'], h5py.Group): + del f['var'] + write_elem(f,'var', self.var) + + # Overwrite `obsm` by iterating items manually + if 'obsm' in f: + del f['obsm'] + obsm_group = f.create_group('obsm') + # Each entry in self.obsm is typically a 2D array or similar + for key, arr in self.obsm.items(): + write_elem(obsm_group, key, arr) + +def _write_raw_bigwigs_to_uns( + h5ad_filename: str, + consensus_peaks: pd.DataFrame, + bigwig_files: list[str], + track_key: str = "tracks", + chunk_size: int = 1024, +) -> None: + """ + Creates and populates an HDF5 dataset uns[track_key] in the .h5ad file with shape: + (n_tracks, n_regions, max_length) + + - n_tracks = number of bigwig files + - n_regions = number of rows in consensus_peaks + - max_length = max(end - start) across all peaks + + Data is written row-by-row to avoid OOM. Each row is one bigWig file, + and we iterate over the consensus_peaks in slices of size chunk_size. + + This is compatible with the LazyTensor class which expects row=track, col=region. + + Parameters + ---------- + h5ad_filename : str + The path to the existing .h5ad file (with minimal placeholder AnnData). + consensus_peaks : pd.DataFrame + DataFrame of consensus peaks. Must have at least columns 0,1,2 = chrom,start,end. + bigwig_files : list[str] + A sorted list of bigwig file paths (one per track). + track_key : str + The key in uns[track_key] to write the raw coverage dataset. + chunk_size : int + For memory reasons, we process the peaks in slices of this size. 
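+
+    Notes
+    -----
+    A rough sketch of how the resulting dataset can be read back directly with
+    h5py (the file name is illustrative; the layout is the one described above,
+    row = track, column = region):
+
+    >>> import h5py
+    >>> with h5py.File("dataset.h5ad", "r") as f:
+    ...     # coverage of the first 10 regions for track 0, shape (10, max_len)
+    ...     cov = f["uns/tracks"][0, :10, :]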
+ """ + # 1) figure out n_regions and max_len + n_regions = len(consensus_peaks) + max_len = 0 + # We assume columns [0,1,2] = (chrom, start, end), or you can adapt if you have named columns + for row_i in range(n_regions): + chrom = consensus_peaks.iat[row_i, 0] + start = consensus_peaks.iat[row_i, 1] + end = consensus_peaks.iat[row_i, 2] + length = int(end) - int(start) + if length > max_len: + max_len = length + + n_tracks = len(bigwig_files) + + # 2) open h5ad in r+ mode to create or overwrite uns[track_key] + with h5py.File(h5ad_filename, "r+") as f: + if "uns" not in f: + f.create_group("uns") + uns_grp = f["uns"] + + # remove old data if present + if track_key in uns_grp: + del uns_grp[track_key] + + # create dataset => shape (n_tracks, n_regions, max_len) + # so row = track_i, col = region_i, last dim = coverage + dset = uns_grp.create_dataset( + track_key, + shape=(n_tracks, n_regions, max_len), + dtype="float32", + chunks=(n_tracks, chunk_size, max_len), # you can tune chunk shapes + fillvalue=np.nan, + ) + + # 3) fill coverage row by row, chunk by chunk + for track_i, bw_file in enumerate(bigwig_files): + logger.info(f"Filling coverage for track {track_i+1}/{n_tracks}: {bw_file}") + + with pybigtools.open(bw_file, "r") as bw: + cur_chroms = bw.chroms().keys() + # read coverage in slices of size=chunk_size + for start_idx in tqdm(range(0, n_regions, chunk_size)): + end_idx = min(start_idx + chunk_size, n_regions) + batch_size = end_idx - start_idx + + # retrieve coverage for each region in [start_idx, end_idx) + batch_signals = [] + local_max_len = 0 + for reg_i in range(start_idx, end_idx): + chrom = consensus_peaks.iat[reg_i, 0] + start = consensus_peaks.iat[reg_i, 1] + end = consensus_peaks.iat[reg_i, 2] + if chrom in cur_chroms: + signal = bw.values( + str(chrom), + int(start), + int(end), + # exact=False, + # missing=np.nan, + oob=np.nan, + ) + else: + signal = np.zeros(int(end-start)) + batch_signals.append(signal) + if len(signal) > local_max_len: + local_max_len = len(signal) + + # clamp local_max_len by global max_len + if local_max_len > max_len: + local_max_len = max_len + + # pad and write each coverage array + for offset, signal in enumerate(batch_signals): + region_i = start_idx + offset + length = len(signal) + if length > max_len: + length = max_len + # write to dset[track_i, region_i, :length] + dset[track_i, region_i, :length] = signal[:length].astype("float32") + + logger.info( + f"Done writing raw coverage to uns[{track_key}] in {h5ad_filename}.\n" + f" Shape = (n_tracks={n_tracks}, n_regions={n_regions}, max_len={max_len})" + ) + + +def read_lazy_h5ad( + filename: str | Path, + mode: Literal["r", "r+"] = "r", + track_key=None, + lazy_varp_keys=None, +) -> LazyAnnData: + """ + Reads an h5ad into a LazyAnnData. + - Should have summary statistic X + - Must have the tracks layer as uns + - If track_key is set (e.g. "tracks"), we link adata.X => uns[track_key] as a 3D LazyTensor. + - If lazy_varp_keys is set, each varp[key] is replaced by a 2D LazyMatrix. 
+ """ + import h5py + + filename = str(filename) + init_kwargs = { + "filename": filename, + "filemode": mode, + "track_key": track_key, + "lazy_varp_keys": lazy_varp_keys, + } + + with h5py.File(filename, mode) as f: + from anndata._io.h5ad import read_elem, read_dataframe + attributes = ["obsm"] + df_attributes = ["obs", "var"] + + # Minimal logic from anndata + if "encoding-type" in f.attrs: + attributes.extend(df_attributes) + else: + for k in df_attributes: + if k in f: + init_kwargs[k] = read_dataframe(f[k]) + + for attr in attributes: + if attr in f: + init_kwargs[attr] = read_elem(f[attr]) + try: + init_kwargs['uns'] = {} + init_kwargs['uns']['params'] = read_elem(f['uns']['params']) + except: + pass + + # Now create the LazyAnnData, which calls _make_tracks_lazy and _make_varp_lazy + adata = LazyAnnData(**init_kwargs) + return adata + +def filter_and_adjust_chromosome_data( + peaks: pd.DataFrame, + chrom_sizes: dict, + max_shift: int = 0, + chrom_col: str = "chrom", + start_col: str = "start", + end_col: str = "end", + MIN_POS: int = 0, +) -> pd.DataFrame: + """ + Expand each peak by `max_shift` on both sides if possible. + If the peak is near the left edge, the leftover shift is added to the right side; + if near the right edge, leftover shift is added to the left side. + + Returns a DataFrame where each row is expanded (unless blocked by chromosome edges). + Rows with an unknown chromosome (i.e., not found in chrom_sizes) are dropped. + + Example: + If a row is: chr1, start=0, end=2114, max_shift=50 + => desired new length = 2114 + 2*50 = 2214 + => final row: (chr1, 0, 2214) + """ + + # 1) Map each row's chromosome to its size + # Rows with missing chrom sizes become NaN → we drop them. + peaks["_chr_size"] = peaks[chrom_col].map(chrom_sizes) + peaks = peaks.dropna(subset=["_chr_size"]).copy() + peaks["_chr_size"] = peaks["_chr_size"].astype(int) + + # Convert to arrays for fast vectorized arithmetic + starts = peaks[start_col].to_numpy(dtype=int) + ends = peaks[end_col].to_numpy(dtype=int) + chr_sizes_arr = peaks["_chr_size"].to_numpy(dtype=int) + + # Original length + orig_length = ends - starts + desired_length = orig_length + 2 * max_shift + + # 2) Temporarily shift left by max_shift + new_starts = starts - max_shift + new_ends = new_starts + desired_length # (so that final length = desired_length) + + # 3) If new_start < MIN_POS, shift leftover to the right + cond_left_edge = new_starts < MIN_POS + # How far below MIN_POS did we go? + shift_needed = MIN_POS - new_starts[cond_left_edge] # positive number + new_starts[cond_left_edge] = MIN_POS + new_ends[cond_left_edge] += shift_needed + + # 4) If new_end > chr_size, shift leftover to the left + cond_right_edge = new_ends > chr_sizes_arr + # How far beyond chromosome size did we go? + shift_needed = new_ends[cond_right_edge] - chr_sizes_arr[cond_right_edge] + new_ends[cond_right_edge] = chr_sizes_arr[cond_right_edge] + new_starts[cond_right_edge] -= shift_needed + + # 5) If shifting back on the left made new_start < MIN_POS again, clamp it. 
+    cond_left_clamp = new_starts < MIN_POS
+    new_starts[cond_left_clamp] = MIN_POS
+
+    # Assign back to DataFrame
+    peaks[start_col] = new_starts
+    peaks[end_col] = new_ends
+
+    peaks.drop(columns=["_chr_size"], inplace=True)
+
+    return peaks
+
+
+def import_bigwigs_raw(
+    bigwigs_folder: PathLike,
+    regions_file: PathLike,
+    h5ad_path: PathLike,
+    target_region_width: int | None,
+    chromsizes_file: PathLike | None = None,
+    genome: Genome | None = None,
+    max_stochastic_shift: int = 0,
+    chunk_size: int = 512,
+) -> AnnData:
+    """
+    Import bigWig files and a consensus regions BED file into a backed AnnData object.
+
+    This format is required to train a base-pair level prediction model.
+    Raw per-base coverage is extracted for each region and written to the backing
+    .h5ad file, with the bigWig file names as .obs and the consensus regions as .var.
+    A target region width is enforced so that all regions have equal size
+    (no ragged tensors allowed). This is often useful to extract sequence
+    information around the actual peak region.
+
+    Parameters
+    ----------
+    bigwigs_folder
+        Folder name containing the bigWig files.
+    regions_file
+        File name of the consensus regions BED file.
+    h5ad_path
+        Path where the backed AnnData (.h5ad) file will be written.
+    target_region_width
+        Width of region that the bigWig coverage will be extracted from. If None, the
+        consensus region width will be used.
+    chromsizes_file
+        File path of the chromsizes file. Used for checking if the new regions are within
+        the chromosome boundaries. Either this or `genome` should be provided.
+    genome
+        Genome object whose chromosome sizes are used if no chromsizes file is provided.
+    max_stochastic_shift
+        Maximum stochastic shift (in bp) that will later be applied during training;
+        regions are widened accordingly so that shifted crops stay in bounds.
+    chunk_size
+        Number of regions written per chunk when filling the coverage dataset.
+
+    Returns
+    -------
+    LazyAnnData object backed by `h5ad_path`, with bigWigs as rows and peaks as columns;
+    the raw per-base coverage is exposed through the lazy "tracks" layer.
+
+    Example
+    -------
+    >>> anndata = crested.import_bigwigs_raw(
+    ...     bigwigs_folder="path/to/bigwigs",
+    ...     regions_file="path/to/peaks.bed",
+    ...     h5ad_path="path/to/output.h5ad",
+    ...     target_region_width=500,
+    ...     chromsizes_file="path/to/chrom.sizes",
+    ...
) + """ + bigwigs_folder = Path(bigwigs_folder) + regions_file = Path(regions_file) + + # Input checks + if not bigwigs_folder.is_dir(): + raise FileNotFoundError(f"Directory '{bigwigs_folder}' not found") + if not regions_file.is_file(): + raise FileNotFoundError(f"File '{regions_file}' not found") + if chromsizes_file is not None: + chromsizes_dict = _read_chromsizes(chromsizes_file) + if genome is not None: + chromsizes_dict = genome.chrom_sizes + + # Read consensus regions BED file and filter out regions not within chromosomes + _check_bed_file_format(regions_file) + consensus_peaks = _read_consensus_regions(regions_file, chromsizes_dict) + region_width = int(np.round(np.mean(consensus_peaks['end'] - consensus_peaks['start']))) + consensus_peaks = consensus_peaks.loc[(consensus_peaks['end']-consensus_peaks['start']) == region_width,:] + consensus_peaks = filter_and_adjust_chromosome_data(consensus_peaks, chromsizes_dict, max_shift=max_stochastic_shift) + shifted_width = (target_region_width+2*max_stochastic_shift) + consensus_peaks = consensus_peaks.loc[(consensus_peaks['end']-consensus_peaks['start']) == shifted_width,:] + + bw_files = [] + chrom_set = set([]) + for file in os.listdir(bigwigs_folder): + file_path = os.path.join(bigwigs_folder, file) + try: + # Validate using pyBigTools (add specific validation if available) + bw = pybigtools.open(file_path, "r") + chrom_set = chrom_set | set(bw.chroms().keys()) + bw_files.append(file_path) + bw.close() + except ValueError: + pass + except pybigtools.BBIReadError: + pass + + + consensus_peaks = consensus_peaks.loc[consensus_peaks["chrom"].isin(chrom_set),:] + + bw_files = sorted(bw_files) + if not bw_files: + raise FileNotFoundError(f"No valid bigWig files found in '{bigwigs_folder}'") + + # Process bigWig files in parallel and collect the results + logger.info(f"Extracting values from {len(bw_files)} bigWig files...") + all_results = [] + + # Prepare obs and var for AnnData + obs_df = pd.DataFrame( + data={"file_path": bw_files}, + index=[ + os.path.basename(file).rpartition(".")[0].replace(".", "_") + for file in bw_files + ], + ) + var_df = consensus_peaks.set_index("region") + + # Create AnnData object + adata = ad.AnnData(X = csr_matrix((obs_df.shape[0],var_df.shape[0])), obs=obs_df, var=var_df) + adata.uns['params'] = {} + adata.uns['params']['target_region_width'] = target_region_width + adata.uns['params']['shifted_region_width'] = shifted_width + adata.uns['params']['max_stochastic_shift'] = max_stochastic_shift + adata.write_h5ad(h5ad_path) + + _write_raw_bigwigs_to_uns( + h5ad_filename=h5ad_path, + consensus_peaks=consensus_peaks, + bigwig_files=bw_files, + track_key="tracks", + chunk_size=chunk_size, + ) + + lazy_adata = read_lazy_h5ad(filename=h5ad_path, mode="r+" , track_key="tracks")#, lazy_varp_keys=[varp_keys] + return lazy_adata + + + + + diff --git a/src/crested/tl/data/__init__.py b/src/crested/tl/data/__init__.py index f9200d8e..6f485cf3 100644 --- a/src/crested/tl/data/__init__.py +++ b/src/crested/tl/data/__init__.py @@ -1,5 +1,5 @@ """Init file for data module.""" -from ._anndatamodule import AnnDataModule +from ._anndatamodule import AnnDataModule, MetaAnnDataModule from ._dataloader import AnnDataLoader -from ._dataset import AnnDataset, SequenceLoader +from ._dataset import AnnDataset, SequenceLoader, MetaAnnDataset diff --git a/src/crested/tl/data/_anndatamodule.py b/src/crested/tl/data/_anndatamodule.py index 924af060..b46cf48a 100644 --- a/src/crested/tl/data/_anndatamodule.py +++ 
b/src/crested/tl/data/_anndatamodule.py @@ -3,11 +3,53 @@ from __future__ import annotations from os import PathLike +from torch.utils.data import Sampler +import numpy as np from crested._genome import Genome, _resolve_genome +from anndata import AnnData from ._dataloader import AnnDataLoader -from ._dataset import AnnDataset +from ._dataset import AnnDataset, MetaAnnDataset + + +def set_stage_sample_probs(adata, stage: str): + """ + If stage == 'train', then all regions with adata.var['split'] == 'train' + get sample_probs = adata.var['train_probs'], else 0. + + If stage == 'test', similarly do sample_probs = adata.var['test_probs'], else 0. + + If stage == 'val', do sample_probs = adata.var['val_probs'], else 0. + """ + # Validate we have columns + required_cols = ["split"] + for c in required_cols: + if c not in adata.var: + raise KeyError(f"Missing column {c} in adata.var") + + sample_probs = np.zeros(adata.n_vars, dtype=float) + + if stage == "train": + mask = (adata.var["split"] == "train") + if "train_probs" not in adata.var: + adata.var["train_probs"] = 1. + adata.var["train_probs"] = adata.var["train_probs"]/adata.var["train_probs"].sum() + sample_probs[mask] = adata.var["train_probs"][mask].values + adata.var["sample_probs"] = sample_probs + adata.var["sample_probs"] = adata.var["sample_probs"]/adata.var["sample_probs"].sum() + mask = (adata.var["split"] == "val") + adata.var["val_probs"] = mask.astype(float) + adata.var["val_probs"] = adata.var["val_probs"]/adata.var["val_probs"].sum() + elif stage == "test": + mask = (adata.var["split"] == "test") + adata.var["test_probs"] = mask.astype(float) + adata.var["test_probs"] = adata.var["test_probs"]/adata.var["test_probs"].sum() + elif stage == "predict": + adata.var["predict_probs"] = 1. + adata.var["predict_probs"] = adata.var["predict_probs"]/adata.var["predict_probs"].sum() + else: + print("Invalid stage, sample probabilites unchanged") class AnnDataModule: @@ -65,16 +107,20 @@ def __init__( genome: PathLike | Genome | None = None, chromsizes_file: PathLike | None = None, in_memory: bool = True, - always_reverse_complement=True, + always_reverse_complement: bool = True, random_reverse_complement: bool = False, max_stochastic_shift: int = 0, deterministic_shift: bool = False, shuffle: bool = True, batch_size: int = 256, + data_sources: dict[str, str] = {'y':'X'}, + obs_columns: list[str] | None = None, + obsm_keys: list[str] | None = None, + varp_keys: list[str] | None = None, ): """Initialize the DataModule with the provided dataset and options.""" self.adata = adata - self.genome = _resolve_genome(genome, chromsizes_file) # backward compatibility + self.genome = _resolve_genome(genome, chromsizes_file) # Function assumed available self.always_reverse_complement = always_reverse_complement self.in_memory = in_memory self.random_reverse_complement = random_reverse_complement @@ -82,6 +128,10 @@ def __init__( self.deterministic_shift = deterministic_shift self.shuffle = shuffle self.batch_size = batch_size + self.data_sources = data_sources + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys self._validate_init_args(random_reverse_complement, always_reverse_complement) @@ -105,56 +155,64 @@ def setup(self, stage: str) -> None: Generates the train, val, test or predict dataset based on the provided stage. Should always be called before accessing the dataloaders. - Generally you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. 
+ Generally, you don't need to call this directly, as this is called inside the `tl.Crested` trainer class. Parameters ---------- stage Stage for which to setup the dataloader. Either 'fit', 'test' or 'predict'. """ + args = { + "adata": self.adata, + "genome": self.genome, + "data_sources": self.data_sources, + "in_memory": self.in_memory, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + } + if stage == "fit": - self.train_dataset = AnnDataset( - self.adata, - self.genome, - split="train", - in_memory=self.in_memory, - always_reverse_complement=self.always_reverse_complement, - random_reverse_complement=self.random_reverse_complement, - max_stochastic_shift=self.max_stochastic_shift, - deterministic_shift=self.deterministic_shift, - ) - self.val_dataset = AnnDataset( - self.adata, - self.genome, - split="val", - in_memory=self.in_memory, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + # Training dataset + train_args = args.copy() + train_args["split"] = "train" + set_stage_sample_probs(self.adata, "train") + self.train_dataset = AnnDataset(**train_args) + val_args = args.copy() + val_args["split"] = "val" + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + self.val_dataset = AnnDataset(**val_args) + elif stage == "test": - self.test_dataset = AnnDataset( - self.adata, - self.genome, - split="test", - in_memory=False, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + test_args = args.copy() + test_args["split"] = "test" + test_args["in_memory"] = False + test_args["always_reverse_complement"] = False + test_args["random_reverse_complement"] = False + test_args["max_stochastic_shift"] = 0 + set_stage_sample_probs(self.adata, "test") + self.test_dataset = AnnDataset(**test_args) + elif stage == "predict": - self.predict_dataset = AnnDataset( - self.adata, - self.genome, - split=None, - in_memory=False, - always_reverse_complement=False, - random_reverse_complement=False, - max_stochastic_shift=0, - ) + predict_args = args.copy() + predict_args["split"] = None + predict_args["in_memory"] = False + predict_args["always_reverse_complement"] = False + predict_args["random_reverse_complement"] = False + predict_args["max_stochastic_shift"] = 0 + set_stage_sample_probs(self.adata, stage="predict") + self.predict_dataset = AnnDataset(**predict_args) + else: raise ValueError(f"Invalid stage: {stage}") + @property def train_dataloader(self): """:obj:`crested.tl.data.AnnDataLoader`: Training dataloader.""" @@ -165,6 +223,7 @@ def train_dataloader(self): batch_size=self.batch_size, shuffle=self.shuffle, drop_remainder=False, + stage='train' ) @property @@ -177,6 +236,7 @@ def val_dataloader(self): batch_size=self.batch_size, shuffle=False, drop_remainder=False, + stage='val' ) @property @@ -189,6 +249,7 @@ def test_dataloader(self): batch_size=self.batch_size, shuffle=False, drop_remainder=False, + stage='test' ) @property @@ -201,6 +262,7 @@ def predict_dataloader(self): batch_size=self.batch_size, shuffle=False, drop_remainder=False, + stage='predict' ) def __repr__(self): @@ -213,3 +275,273 @@ def __repr__(self): 
            f"max_stochastic_shift={self.max_stochastic_shift}, shuffle={self.shuffle}, "
            f"batch_size={self.batch_size}"
        )
+
+
+class MetaSampler(Sampler):
+    """
+    A Sampler that yields indices in proportion to meta_dataset.global_probs.
+    """
+
+    def __init__(self, meta_dataset: MetaAnnDataset, epoch_size: int = 100_000):
+        """
+        Parameters
+        ----------
+        meta_dataset : MetaAnnDataset
+            The combined dataset with global_indices and global_probs.
+        epoch_size : int
+            How many samples we consider in one epoch of training.
+        """
+        super().__init__(data_source=meta_dataset)
+        self.meta_dataset = meta_dataset
+        self.epoch_size = epoch_size
+
+        # Check that the sum of global_probs is ~1.0
+        s = self.meta_dataset.global_probs.sum()
+        if not np.isclose(s, 1.0, atol=1e-6):
+            raise ValueError(
+                "global_probs do not sum to 1 after final normalization. sum = {}".format(s)
+            )
+
+    def __iter__(self):
+        """
+        For each epoch, yield 'epoch_size' random draws from
+        [0..len(meta_dataset)-1], weighted by global_probs.
+        """
+        n = len(self.meta_dataset)
+        p = self.meta_dataset.global_probs
+        for _ in range(self.epoch_size):
+            yield np.random.choice(n, p=p)
+
+    def __len__(self):
+        """
+        The DataLoader uses len(sampler) to figure out how many samples per epoch.
+        """
+        return self.epoch_size
+
+
+class MetaAnnDataModule:
+    """
+    A DataModule for multiple AnnData objects (one per species),
+    merging them into a single MetaAnnDataset for each stage.
+    The MetaSampler then performs globally weighted sampling.
+
+    We do NOT physically reindex the obs dimension. Instead, each AnnData
+    may have a different set of obs_names. The code which loads coverage or X
+    at the dataset level is expected to handle label-based indexing and fill
+    missing rows with NaN as needed.
+    """
+
+    def __init__(
+        self,
+        adatas: list[AnnData],
+        genomes: list[Genome],
+        data_sources: dict[str, str] = {'y': 'X'},
+        in_memory: bool = True,
+        random_reverse_complement: bool = True,
+        max_stochastic_shift: int = 0,
+        deterministic_shift: bool = False,
+        shuffle: bool = True,
+        batch_size: int = 256,
+        obs_columns: list[str] | None = None,
+        obsm_keys: list[str] | None = None,
+        varp_keys: list[str] | None = None,
+        epoch_size: int = 100_000,
+    ):
+        """
+        Parameters
+        ----------
+        adatas : list[AnnData]
+            Each species or dataset is stored in its own AnnData.
+        genomes : list[Genome]
+            Matching list of genome references for each AnnData.
+        data_sources : dict[str, str]
+            Mapping from output key to the AnnData field it is read from
+            (e.g. {'y': 'X'} or {'y': 'layers/tracks'}), passed on to each AnnDataset.
+        in_memory : bool
+            If True, sequences might be loaded into memory in each AnnDataset.
+        random_reverse_complement : bool
+            If True, we randomly reverse complement each region
+            (always_reverse_complement is not available for MetaAnnDataModule).
+        max_stochastic_shift : int
+            Maximum shift (±) to apply to each region for data augmentation.
+        deterministic_shift : bool
+            If True, do the older style shifting in fixed strides.
+        shuffle : bool
+            Whether to shuffle the dataset in the dataloader.
+        batch_size : int
+            How many samples per batch.
+        obs_columns : list[str]
+            Any obs columns from each AnnData to replicate in the dataset item.
+        obsm_keys : list[str]
+            Any obsm keys from each AnnData to replicate.
+        varp_keys : list[str]
+            Any varp keys from each AnnData to replicate.
+        epoch_size : int
+            How many samples per epoch for the DataLoader if using a custom sampler.
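+
+        Example
+        -------
+        A minimal sketch (variable names are illustrative):
+
+        >>> dm = MetaAnnDataModule(
+        ...     adatas=[adata_human, adata_mouse],
+        ...     genomes=[genome_human, genome_mouse],
+        ...     data_sources={"y": "layers/tracks"},
+        ...     batch_size=128,
+        ...     epoch_size=50_000,
+        ... )
+        >>> dm.setup("fit")
+        >>> train_loader = dm.train_dataloader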
+ """ + if len(adatas) != len(genomes): + raise ValueError("Must provide as many `adatas` as `genomes`.") + + self.adatas = adatas + self.genomes = genomes + self.in_memory = in_memory + self.always_reverse_complement = False + self.random_reverse_complement = random_reverse_complement + self.max_stochastic_shift = max_stochastic_shift + self.deterministic_shift = deterministic_shift + self.shuffle = shuffle + self.batch_size = batch_size + self.data_sources = data_sources + self.obs_columns = obs_columns + self.obsm_keys = obsm_keys + self.varp_keys = varp_keys + self.epoch_size = epoch_size + + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.predict_dataset = None + + self.meta_obs_names = np.array(set().union(*[adata.obs_names for adata in self.adatas])) + for adata in self.adatas: + adata.meta_obs_names = self.meta_obs_names + + def setup(self, stage: str) -> None: + """ + Create the AnnDataset objects for each adata+genome, then unify them + into a MetaAnnDataset for the given stage. + + Unlike older code, we do NOT reindex each AnnData's obs dimension. + Instead, each AnnDataset can handle label-based indexing or + fill missing rows with NaN via lazy structures. + """ + def dataset_args(split): + return { + "in_memory": self.in_memory, + "data_sources": self.data_sources, + "always_reverse_complement": self.always_reverse_complement, + "random_reverse_complement": self.random_reverse_complement, + "max_stochastic_shift": self.max_stochastic_shift, + "deterministic_shift": self.deterministic_shift, + "obs_columns": self.obs_columns, + "obsm_keys": self.obsm_keys, + "varp_keys": self.varp_keys, + "split": split, + } + + if stage == "fit": + train_datasets = [] + val_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + # Training + args = dataset_args("train") + set_stage_sample_probs(adata, "train") + ds_train = AnnDataset(adata=adata, genome=genome, **args) + train_datasets.append(ds_train) + + # Validation (no shifting, no RC) + val_args = dataset_args("val") + val_args["always_reverse_complement"] = False + val_args["random_reverse_complement"] = False + val_args["max_stochastic_shift"] = 0 + ds_val = AnnDataset(adata=adata, genome=genome, **val_args) + val_datasets.append(ds_val) + + # Merge them with MetaAnnDataset + self.train_dataset = MetaAnnDataset(train_datasets) + self.val_dataset = MetaAnnDataset(val_datasets) + + elif stage == "test": + test_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args("test") + set_stage_sample_probs(adata, "test") + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_test = AnnDataset(adata=adata, genome=genome, **args) + test_datasets.append(ds_test) + + self.test_dataset = MetaAnnDataset(test_datasets) + + elif stage == "predict": + predict_datasets = [] + for adata, genome in zip(self.adatas, self.genomes): + args = dataset_args(None) + set_stage_sample_probs(adata, "predict") + args["in_memory"] = False + args["always_reverse_complement"] = False + args["random_reverse_complement"] = False + args["max_stochastic_shift"] = 0 + + ds_pred = AnnDataset(adata=adata, genome=genome, **args) + predict_datasets.append(ds_pred) + + self.predict_dataset = MetaAnnDataset(predict_datasets) + + else: + raise ValueError(f"Invalid stage: {stage}") + + @property + def train_dataloader(self): + if self.train_dataset is None: + raise ValueError("train_dataset is not 
set. Run setup('fit') first.") + return AnnDataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_remainder=False, + epoch_size=self.epoch_size, + stage='train' + ) + + @property + def val_dataloader(self): + if self.val_dataset is None: + raise ValueError("val_dataset is not set. Run setup('fit') first.") + return AnnDataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + stage='val' + ) + + @property + def test_dataloader(self): + if self.test_dataset is None: + raise ValueError("test_dataset is not set. Run setup('test') first.") + return AnnDataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + stage='test' + ) + + @property + def predict_dataloader(self): + if self.predict_dataset is None: + raise ValueError("predict_dataset is not set. Run setup('predict') first.") + return AnnDataLoader( + self.predict_dataset, + batch_size=self.batch_size, + shuffle=False, + drop_remainder=False, + epoch_size=self.epoch_size, + stage='predict' + ) + + def __repr__(self): + return ( + f"MetaAnnDataModule(" + f"num_species={len(self.adatas)}, " + f"batch_size={self.batch_size}, shuffle={self.shuffle}, " + f"max_stochastic_shift={self.max_stochastic_shift}, " + f"random_reverse_complement={self.random_reverse_complement}, " + f"always_reverse_complement={self.always_reverse_complement}, " + f"in_memory={self.in_memory}, " + f"deterministic_shift={self.deterministic_shift}, " + f"epoch_size={self.epoch_size}" + f")" + ) diff --git a/src/crested/tl/data/_dataloader.py b/src/crested/tl/data/_dataloader.py index 8a756a88..1f51e54f 100644 --- a/src/crested/tl/data/_dataloader.py +++ b/src/crested/tl/data/_dataloader.py @@ -3,16 +3,142 @@ from __future__ import annotations import os +from collections import defaultdict +from ._dataset import AnnDataset, MetaAnnDataset if os.environ["KERAS_BACKEND"] == "torch": import torch - from torch.utils.data import DataLoader + from torch.utils.data import DataLoader, Sampler else: import tensorflow as tf from ._dataset import AnnDataset +from torch.utils.data import Sampler +import numpy as np + +class WeightedRegionSampler(Sampler): + def __init__(self, dataset, epoch_size=100_000): + super().__init__(data_source=dataset) + self.dataset = dataset + self.epoch_size = epoch_size + p = dataset.augmented_probs + s = p.sum() + if s <= 0: + raise ValueError("All sample_probs are zero, cannot sample.") + self.probs = p / s + + def __iter__(self): + n = len(self.dataset.index_manager.augmented_indices) + for _ in range(self.epoch_size): + yield np.random.choice(n, p=self.probs) + + def __len__(self): + return self.epoch_size + +class NonShuffleRegionSampler(Sampler): + """ + Enumerate each region with sample_probs>0 exactly once, in a deterministic order. + """ + + def __init__(self, dataset): + super().__init__(data_source=dataset) + self.dataset = dataset + + # We get the augmented_probs from dataset.augmented_probs + # We'll filter out any zero-prob entries + p = self.dataset.augmented_probs + self.nonzero_indices = np.flatnonzero(p > 0.0) # e.g. [0,1,5,...] 
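+        # These positions index into dataset.index_manager.augmented_indices,
+        # i.e. the augmented regions whose split-specific probability is nonzero.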
+ if len(self.nonzero_indices) == 0: + raise ValueError("No nonzero probabilities for val/test stage.") + + def __iter__(self): + # Return each index once, in ascending order + # or sort by some custom logic + return iter(self.nonzero_indices) + + def __len__(self): + # The DataLoader sees how many samples in an epoch + return len(self.nonzero_indices) + +class MetaSampler(Sampler): + """ + A Sampler that yields indices in proportion to meta_dataset.global_probs. + """ + + def __init__(self, meta_dataset: MetaAnnDataset, epoch_size: int = 100_000): + """ + Parameters + ---------- + meta_dataset : MetaAnnDataset + The combined dataset with global_indices and global_probs. + epoch_size : int + How many samples we consider in one epoch of training. + """ + super().__init__(data_source=meta_dataset) + self.meta_dataset = meta_dataset + self.epoch_size = epoch_size + + # Check that sum of global_probs ~ 1.0 + s = self.meta_dataset.global_probs.sum() + if not np.isclose(s, 1.0, atol=1e-6): + raise ValueError( + "global_probs do not sum to 1 after final normalization. sum = {}".format(s) + ) + + def __iter__(self): + """ + For each epoch, yield 'epoch_size' random draws from + [0..len(meta_dataset)-1], weighted by global_probs. + """ + n = len(self.meta_dataset) + p = self.meta_dataset.global_probs + for _ in range(self.epoch_size): + yield np.random.choice(n, p=p) + + def __len__(self): + """ + The DataLoader uses len(sampler) to figure out how many samples per epoch. + """ + return self.epoch_size + +class NonShuffleMetaSampler(Sampler): + """ + A Sampler for MetaAnnDataset that enumerates all indices + with nonzero global_probs exactly once, in ascending order. + + Typically used for val/test phases, ensuring deterministic + coverage of all relevant entries. + """ + + def __init__(self, meta_dataset, sort=True): + """ + Parameters + ---------- + meta_dataset : MetaAnnDataset + The combined dataset with .global_indices and .global_probs. + sort : bool + If True, sort the nonzero indices ascending. If False, keep them in + the existing order. You can also implement your own custom ordering. + """ + super().__init__(data_source=meta_dataset) + self.meta_dataset = meta_dataset + + # We'll gather the set of global indices with probability > 0 + p = self.meta_dataset.global_probs + self.nonzero_global_indices = np.flatnonzero(p > 0) + if sort: + self.nonzero_global_indices.sort() + + def __iter__(self): + # yields each global index exactly once + return iter(self.nonzero_global_indices) + + def __len__(self): + return len(self.nonzero_global_indices) + + class AnnDataLoader: """ Pytorch-like DataLoader class for AnnDataset with options for batching, shuffling, and one-hot encoding. @@ -38,43 +164,87 @@ class AnnDataLoader: >>> for x, y in dataloader.data: ... 
# Your training loop here """ - def __init__( self, - dataset: AnnDataset, + dataset, # can be AnnDataset or MetaAnnDataset batch_size: int, shuffle: bool = False, drop_remainder: bool = True, + epoch_size: int = 100_000, + stage: str = "train", # if you want train/val/test logic ): - """Initialize the DataLoader with the provided dataset and options.""" self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle self.drop_remainder = drop_remainder - if os.environ["KERAS_BACKEND"] == "torch": + self.epoch_size = epoch_size + self.stage = stage + + if os.environ.get("KERAS_BACKEND", "") == "torch": self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = None + + self.sampler = None - if self.shuffle: - self.dataset.shuffle = True + # 1) If it's a MetaAnnDataset => use MetaSampler or NonShuffleMetaSampler + if isinstance(dataset, MetaAnnDataset): + if self.stage == "train": + self.sampler = MetaSampler(dataset, epoch_size=self.epoch_size) + else: + # For val/test => enumerates all nonzero-prob entries once + self.sampler = NonShuffleMetaSampler(dataset, sort=True) + + # 2) If it's a single AnnDataset => check for augmented_probs + else: + # Single AnnDataset => WeightedRegionSampler or NonShuffleRegionSampler + if getattr(dataset, "augmented_probs", None) is not None: + if self.stage == "train": + self.sampler = WeightedRegionSampler(dataset, epoch_size=self.epoch_size) + else: + # e.g. val/test => enumerates nonzero-prob entries once + self.sampler = NonShuffleRegionSampler(dataset) + else: + # No probabilities => uniform or user logic + if self.shuffle and hasattr(self.dataset, "shuffle"): + self.dataset.shuffle = True def _collate_fn(self, batch): - """Collate function to move tensors to the specified device if backend is torch.""" - inputs, targets = zip(*batch) - inputs = torch.stack([torch.tensor(input) for input in inputs]).to(self.device) - targets = torch.stack([torch.tensor(target) for target in targets]).to( - self.device - ) - return inputs, targets + """ + Collate function to gather list of sample-dicts into a single batched dict of tensors. 
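+        Assumes every value in each sample dict is numeric/array-like (so it can be
+        wrapped by torch.tensor); string or categorical fields would need to be
+        encoded upstream (e.g. via obs category codes).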
+ """ + x = defaultdict(list) + for sample_dict in batch: + for key, val in sample_dict.items(): + x[key].append(torch.tensor(val, dtype=torch.float32)) + + # Stack and move to device + for key in x.keys(): + x[key] = torch.stack(x[key], dim=0) + if self.device is not None: + x[key] = x[key].to(self.device) + return x def _create_dataset(self): - if os.environ["KERAS_BACKEND"] == "torch": - return DataLoader( - self.dataset, - batch_size=self.batch_size, - drop_last=self.drop_remainder, - num_workers=0, - collate_fn=self._collate_fn, - ) + if os.environ.get("KERAS_BACKEND", "") == "torch": + if self.sampler is not None: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + sampler=self.sampler, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) + else: + return DataLoader( + self.dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + drop_last=self.drop_remainder, + num_workers=0, + collate_fn=self._collate_fn, + ) elif os.environ["KERAS_BACKEND"] == "tensorflow": ds = tf.data.Dataset.from_generator( self.dataset, @@ -97,7 +267,10 @@ def data(self): def __len__(self): """Return the number of batches in the DataLoader based on the dataset size and batch size.""" - return (len(self.dataset) + self.batch_size - 1) // self.batch_size + if self.sampler is not None: + return (self.epoch_size + self.batch_size - 1) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size def __repr__(self): """Return the string representation of the DataLoader.""" diff --git a/src/crested/tl/data/_dataset.py b/src/crested/tl/data/_dataset.py index 123af489..62755b79 100644 --- a/src/crested/tl/data/_dataset.py +++ b/src/crested/tl/data/_dataset.py @@ -10,6 +10,7 @@ from loguru import logger from scipy.sparse import spmatrix from tqdm import tqdm +import pandas as pd from crested._genome import Genome from crested.utils import one_hot_encode_sequence @@ -273,7 +274,6 @@ def _augment_indices(self, indices: list[str]) -> tuple[list[str], dict[str, str augmented_indices_map[_flip_region_strand(stranded_region)] = region return augmented_indices, augmented_indices_map - if os.environ["KERAS_BACKEND"] == "pytorch": import torch @@ -307,11 +307,17 @@ class AnnDataset(BaseClass): deterministic_shift If true, each region will be shifted twice with stride 50bp to each side. This is our legacy shifting, we recommend using max_stochastic_shift instead. + obs_columns + Columns in obs that will be added to the dataset. + obsm_columns + Keys in obsm that will be added to the dataset. + varp_columns + Keys in varp that will be added to the dataset. 
""" def __init__( self, - anndata: AnnData, + adata: AnnData, genome: Genome, split: str = None, in_memory: bool = True, @@ -319,19 +325,63 @@ def __init__( always_reverse_complement: bool = False, max_stochastic_shift: int = 0, deterministic_shift: bool = False, + data_sources: dict[str, str] = {'y':'X'}, #default to old approach + obs_columns: list[str] | None = None, # multiple obs columns + obsm_keys: list[str] | None = None, # multiple obsm keys + varp_keys: list[str] | None = None, # multiple varp keys + ): + """Initialize the dataset with the provided AnnData object and options.""" - self.anndata = self._split_anndata(anndata, split) + self.adata = self._split_anndata(adata, split) self.split = split - self.indices = list(self.anndata.var_names) + self.indices = list(self.adata.var_names) self.in_memory = in_memory - self.compressed = isinstance(self.anndata.X, spmatrix) + self.compressed = isinstance(self.adata.X, spmatrix) self.index_map = {index: i for i, index in enumerate(self.indices)} - self.num_outputs = self.anndata.X.shape[0] + self.num_outputs = self.adata.X.shape[0] self.random_reverse_complement = random_reverse_complement self.max_stochastic_shift = max_stochastic_shift + self.meta_obs_names = np.array(self.adata.obs_names) self.shuffle = False # managed by wrapping class AnnDataLoader - + self.obs_columns = obs_columns if obs_columns is not None else [] + self.obsm_keys = obsm_keys if obsm_keys is not None else [] + self.varp_keys = varp_keys if varp_keys is not None else [] + self.data_sources = data_sources + self.region_width = adata.uns['params']['target_region_width'] if 'target_region_width' in adata.uns['params'].keys() else int(np.round(np.mean(adata.var['end'] - adata.var['start']))) - (2*self.max_stochastic_shift) + + # Validate and store obs data + self.obs_data = {} + for col in self.obs_columns: + if col not in adata.obs: + raise ValueError(f"obs column '{col}' not found.") + # Convert categorical to integer codes if needed + if pd.api.types.is_categorical_dtype(adata.obs[col]): + self.obs_data[col] = adata.obs[col].cat.codes.values + else: + self.obs_data[col] = adata.obs[col].values + + # Validate and store obsm data + self.obsm_data = {} + for key in self.obsm_keys: + if key not in adata.obsm: + raise ValueError(f"obsm key '{key}' not found.") + mat = adata.obsm[key] + if mat.shape[0] != adata.n_obs: + raise ValueError(f"Dimension mismatch for obsm key '{key}'.") + self.obsm_data[key] = mat + + # Validate and store varp data + self.varp_data = {} + for key in self.varp_keys: + if key not in adata.varp: + raise ValueError(f"varp key '{key}' not found.") + mat = adata.varp[key] + if mat.shape[0] != adata.n_var: + raise ValueError(f"Dimension mismatch for varp key '{key}'.") + self.varp_data[key] = mat + + # Check region formatting stranded = _check_strandedness(self.indices[0]) if stranded and (always_reverse_complement or random_reverse_complement): @@ -356,61 +406,125 @@ def __init__( self.seq_len = len( self.sequence_loader.get_sequence(self.indices[0], stranded=stranded) ) - + + self.augmented_probs = None + if self.split == 'train': + probs = adata.var["train_probs"].values.astype(float) + elif self.split == 'val': + probs = adata.var["val_probs"].values.astype(float) + elif self.split == 'test': + probs = adata.var["test_probs"].values.astype(float) + elif self.split == 'predict': + probs = adata.var["predict_probs"].values.astype(float) + else: + self.augmented_probs = np.ones(adata.shape[1]) + self.augmented_probs = 
self.augmented_probs/self.augmented_probs.sum() + return + probs = np.clip(probs, 0, None) + + n_aug = len(self.index_manager.augmented_indices) + self.augmented_probs = np.ones(n_aug, dtype=float) + self.augmented_probs /= self.augmented_probs.sum() + + for i, aug_region in enumerate(self.index_manager.augmented_indices): + original_region = self.index_manager.augmented_indices_map[aug_region] + var_idx = self.index_map[original_region] + self.augmented_probs[i] = probs[var_idx] + + @staticmethod - def _split_anndata(anndata: AnnData, split: str) -> AnnData: - """Return subset of anndata based on a given split column.""" + def _split_anndata(adata: AnnData, split: str) -> AnnData: + """ + For backward compatibility. Skip physically subsetting for train/val/test. + """ if split: - if "split" not in anndata.var.columns: + if "split" not in adata.var.columns: raise KeyError( - "No split column found in anndata.var. Run `pp.train_val_test_split` first." + "No split column found in adata.var. Run `pp.train_val_test_split` first." ) - subset = ( - anndata[:, anndata.var["split"] == split].copy() - if split - else anndata.copy() - ) - return subset + return adata def __len__(self) -> int: """Get number of (augmented) samples in the dataset.""" return len(self.index_manager.augmented_indices) - def _get_target(self, index: str) -> np.ndarray: - """Get target for a given index.""" - y_index = self.index_map[index] - return ( - self.anndata.X[:, y_index].toarray().flatten() - if self.compressed - else self.anndata.X[:, y_index] - ) + def _get_data_array(self, source_str: str, varname: str, shift: int = 0) -> np.ndarray: + """ + Retrieve data from anndata, given a source string that can be: + - "X" => from self.adata.X + - "layers/" => from self.adata.layers[] + - "varp/" => from self.adata.varp[] + - ... or other expansions + + varname: the name of the var, e.g. "chr1:100-200" + shift: an int to align coverage with the same offset used for DNA + """ + var_i = self.index_map[varname] + + # 2) parse source_str + if source_str == "X": + if self.compressed: + arr = self.adata.X[:, var_i].toarray().flatten() + else: + arr = self.adata.X[:, var_i] + return arr + + elif source_str.startswith("layers/"): + key = source_str.split("/",1)[1] # e.g. 
"tracks" + coverage_3d = self.adata.layers[key] + start_idx = self.max_stochastic_shift + shift + end_idx = start_idx + self.region_width + arr = coverage_3d[self.meta_obs_names, var_i][...,start_idx:end_idx] + return np.asarray(arr) + elif source_str.startswith("varp/"): + key = source_str.split("/",1)[1] + mat = self.varp_data[key] + row = mat[var_i] + return np.asarray(row) + else: + raise ValueError(f"Unknown data source {source_str}.") + - def __getitem__(self, idx: int) -> tuple[str, np.ndarray]: - """Return sequence and target for a given index.""" + def __getitem__(self, idx: int) -> dict: + """ + Returns a dictionary that might contain: + - "sequence": the one-hot DNA + - plus any # of keys from self.data_sources + """ + # 1) pick region augmented_index = self.index_manager.augmented_indices[idx] original_index = self.index_manager.augmented_indices_map[augmented_index] - # stochastic shift + + # 2) pick the random shift, so that DNA and track remain aligned + shift = 0 if self.max_stochastic_shift > 0: - shift = np.random.randint( - -self.max_stochastic_shift, self.max_stochastic_shift + 1 - ) - else: - shift = 0 - - # Get sequence - x = self.sequence_loader.get_sequence( - augmented_index, stranded=True, shift=shift - ) - - # random reverse complement (always_reverse_complement is done in the sequence loader) + shift = np.random.randint(-self.max_stochastic_shift, self.max_stochastic_shift + 1) + + # 3) get DNA sequence + x_seq = self.sequence_loader.get_sequence(augmented_index, stranded=True, shift=shift) if self.random_reverse_complement and np.random.rand() < 0.5: - x = self.sequence_loader._reverse_complement(x) - - # one hot encode sequence and convert to numpy array - x = one_hot_encode_sequence(x, expand_dim=False) - y = self._get_target(original_index) - - return x, y + x_seq = self.sequence_loader._reverse_complement(x_seq) + # x_seq = one_hot_encode_sequence(x_seq, expand_dim=False) + x_seq = one_hot_encode_sequence(x_seq, expand_dim=False) + + + item = { + "sequence": x_seq, + } + + original_varname = original_index # e.g. "chr1:100-200" + for name, source_str in self.data_sources.items(): + if name == "sequence": + continue + arr = self._get_data_array(source_str, original_varname, shift=shift) + item[name] = arr + + for col in self.obs_columns: + item[col] = self.obs_data[col] + for key in self.obsm_keys: + item[key] = self.obsm_data[key] + + return item def __call__(self): """Call generator for the dataset.""" @@ -422,4 +536,81 @@ def __call__(self): def __repr__(self) -> str: """Get string representation of the dataset.""" - return f"AnnDataset(anndata_shape={self.anndata.shape}, n_samples={len(self)}, num_outputs={self.num_outputs}, split={self.split}, in_memory={self.in_memory})" + return f"AnnDataset(anndata_shape={self.adata.shape}, n_samples={len(self)}, num_outputs={self.num_outputs}, split={self.split}, in_memory={self.in_memory})" + +class MetaAnnDataset: + """ + Combines multiple AnnDataset objects into a single dataset, + merging all their (augmented_index, probability) pairs into one global list. + + We do a final normalization across all sub-datasets so that + sample_prob from each dataset is treated as an unnormalized weight. + """ + + def __init__(self, datasets: list[AnnDataset]): + """ + Parameters + ---------- + datasets : list of AnnDataset + Each AnnDataset is for a different species or annotation set. 
+ """ + if not datasets: + raise ValueError("No AnnDataset provided to MetaAnnDataset.") + + self.datasets = datasets + self.always_reverse_complement = False + # global_indices will store tuples of (dataset_idx, local_idx) + # global_probs will store the merged, unnormalized probabilities + self.global_indices = [] + self.global_probs = [] + + for ds_idx, ds in enumerate(datasets): + ds_len = len(ds.index_manager.augmented_indices) + if ds_len == 0: + continue + + # If the dataset has augmented_probs, we use them as unnormalized weights + # If not, fallback to 1.0 for each region + if ds.augmented_probs is not None: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(ds.augmented_probs[local_i]) + else: + for local_i in range(ds_len): + self.global_indices.append((ds_idx, local_i)) + self.global_probs.append(1.0) + + self.global_indices = np.array(self.global_indices, dtype=object) + self.global_probs = np.array(self.global_probs, dtype=float) + + # Normalize across the entire set + total = self.global_probs.sum() + if total > 0: + self.global_probs /= total + else: + # fallback: uniform if everything is zero + n = len(self.global_probs) + if n > 0: + self.global_probs.fill(1.0 / n) + + def __len__(self): + """ + The total number of augmented indices across all sub-datasets. + """ + return len(self.global_indices) + + def __getitem__(self, global_idx: int): + """ + A DataLoader or sampler will pass a global_idx in [0..len(self)-1]. + We map that to (dataset_idx, local_i) and call the sub-dataset's __getitem__. + """ + ds_idx, local_i = self.global_indices[global_idx] + ds_idx = int(ds_idx) + local_i = int(local_i) + return self.datasets[ds_idx][local_i] + + def __repr__(self): + return (f"MetaAnnDataset(num_datasets={len(self.datasets)}, " + f"total_augmented_indices={len(self.global_indices)})") + +