add iterate function to data classes (see #19)
aryarm committed Apr 13, 2022
1 parent 18f5a97 commit 47f61e6
Showing 5 changed files with 174 additions and 6 deletions.
46 changes: 45 additions & 1 deletion haptools/data/covariates.py
@@ -75,7 +75,7 @@ def read(self, samples: list[str] = None):
super().read()
# load all info into memory
# use hook_compressed to automatically handle gz files
with hook_compressed(self.fname, mode='rt') as covars:
with hook_compressed(self.fname, mode="rt") as covars:
covar_text = reader(covars, delimiter="\t")
header = next(covar_text)
# there should be at least two columns
@@ -107,3 +107,47 @@ def read(self, samples: list[str] = None):
self.names = tuple(header[1:])
# coerce strings to floats
self.data = np.transpose(np.array(data[1:], dtype="float64"))

def iterate(self, samples: list[str] = None) -> Iterator[dict]:
"""
Read covariates from a TSV line by line without storing anything
Parameters
----------
samples : list[str], optional
A subset of the samples from which to extract covariates
Defaults to loading covariates from all samples
Yields
------
Iterator[dict]
An iterator over each line in the file, where each line is encoded as a
dictionary containing each of the class properties
"""
with hook_compressed(self.fname, mode="rt") as covars:
covar_text = reader(covars, delimiter="\t")
header = next(covar_text)
# there should be at least two columns
assert (
len(header) >= 2
), "The covariates TSV should have at least two columns."
# the first column should be called "sample"
assert header[0] == "sample", (
"The first column of the covariates TSV should contain sample IDs and"
" should be named 'sample' in the header line"
)
header = tuple(header[1:])
for covar in covar_text:
if samples is None or covar[0] in samples:
try:
yield {
"samples": covar[0],
"names": header,
"data": np.array(covar[1:], dtype="float64"),
}
except:
raise AssertionError(
"Every column in the covariates file (besides the sample"
" column) must be numeric."
)
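
A minimal usage sketch (not part of the commit) of the new Covariates.iterate() method; the file name and sample IDs below are placeholders, and the class is assumed to be importable from haptools.data.

from haptools.data import Covariates

covars = Covariates("covars.tsv")
for line in covars.iterate(samples=["HG00096", "HG00097"]):
    # each yielded dict mirrors the class properties: "samples", "names", and "data"
    print(line["samples"], dict(zip(line["names"], line["data"])))

Note that each record's "samples" entry holds a single sample ID, which is how the tests below index it.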
14 changes: 14 additions & 0 deletions haptools/data/data.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from csv import reader
from pathlib import Path
from typing import Iterator
from abc import ABC, abstractmethod
from logging import getLogger, Logger

@@ -50,3 +51,16 @@ def read(self):
"""
if self.data is not None:
self.log.warning("The data has already been loaded. Overriding.")

@abstractmethod
def iterate(self) -> Iterator[dict]:
"""
Return an iterator over the raw file contents
Yields
------
Iterator[dict]
An iterator over each line in the file, where each line is encoded as a
dictionary containing each of the class properties
"""
pass
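
A hedged sketch of consuming this abstract iterate() contract generically; the helper name peek() is hypothetical, and the import assumes the base class is the Data class defined in haptools/data/data.py.

from itertools import islice

from haptools.data.data import Data

def peek(data: Data, n: int = 2) -> list[dict]:
    # pull only the first n records from any Data subclass; nothing is stored on the instance
    return list(islice(data.iterate(), n))

Because iterate() is a generator, only the first n records are ever parsed, regardless of the file's size.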
55 changes: 51 additions & 4 deletions haptools/data/genotypes.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Iterator

import numpy as np
from cyvcf2 import VCF, Variant
@@ -91,11 +92,12 @@ def read(self, region: str = None, samples: list[str] = None):
Defaults to loading genotypes from all samples
"""
super().read()
# load all info into memory
# initialize variables
vcf = VCF(str(self.fname), samples=samples)
self.samples = tuple(vcf.samples)
self.variants = []
self.data = []
# load all info into memory
for variant in vcf(region):
# save meta information about each variant
self.variants.append((variant.ID, variant.CHROM, variant.POS, variant.aaf))
@@ -125,6 +127,53 @@ def read(self, region: str = None, samples: list[str] = None):
# transpose the GT matrix so that samples are rows and variants are columns
self.data = self.data.transpose((1, 0, 2))

def iterate(self, region: str = None, samples: list[str] = None) -> Iterator[dict]:
"""
Read genotypes from a VCF line by line without storing anything
Parameters
----------
region : str, optional
The region from which to extract genotypes; ex: 'chr1:1234-34566' or 'chr7'
For this to work, the VCF must be indexed and the seqname must match!
Defaults to loading all genotypes
samples : list[str], optional
A subset of the samples from which to extract genotypes
Defaults to loading genotypes from all samples
Yields
------
Iterator[dict]
An iterator over each line in the file, where each line is encoded as a
dictionary containing each of the class properties
"""
vcf = VCF(str(self.fname), samples=samples)
samples = tuple(vcf.samples)
# yield each record as it is parsed from the VCF, without storing anything
for variant in vcf(region):
record = {"samples": samples}
# save meta information about each variant
record["variants"] = np.array(
(variant.ID, variant.CHROM, variant.POS, variant.aaf),
dtype=[
("id", "U50"),
("chrom", "U10"),
("pos", np.uint),
("aaf", np.float64),
],
)
# extract the genotypes to a matrix of size 1 x p x 3
# the last dimension has three items:
# 1) presence of REF in strand one
# 2) presence of REF in strand two
# 3) whether the genotype is phased
record["data"] = np.array(variant.genotypes, dtype=np.uint8)
yield record
vcf.close()

def check_biallelic(self, discard_also=False):
"""
Check that each genotype is composed of only two alleles
@@ -179,9 +228,7 @@ def check_phase(self):
If any heterozygous genotypes are unphased
"""
if self.data.shape[2] < 3:
self.log.warning(
"Phase information has already been removed from the data"
)
self.log.warning("Phase information has already been removed from the data")
return
# check: are there any variants that are heterozygous and unphased?
unphased = (self.data[:, :, 0] ^ self.data[:, :, 1]) & (~self.data[:, :, 2])
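
A minimal sketch (not from the commit) of streaming records with the new Genotypes.iterate(); the VCF path, region string, and sample IDs below are placeholders, and the region filter assumes an indexed VCF, as the docstring notes.

from haptools.data import Genotypes

gts = Genotypes("simple.vcf.gz")
for record in gts.iterate(region="chr1:10000-20000", samples=["HG00096", "HG00097"]):
    # record["data"] has one row per requested sample: the two per-strand values plus the phase flag
    print(record["variants"]["id"], record["data"])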
30 changes: 29 additions & 1 deletion haptools/data/phenotypes.py
@@ -75,7 +75,7 @@ def read(self, samples: list[str] = None):
super().read()
# load all info into memory
# use hook_compressed to automatically handle gz files
with hook_compressed(self.fname, mode='rt') as phens:
with hook_compressed(self.fname, mode="rt") as phens:
phen_text = reader(phens, delimiter="\t")
# convert to list and subset samples if need be
if samples:
@@ -95,6 +95,34 @@ def read(self, samples: list[str] = None):
# coerce strings to floats
self.data = np.array(self.data, dtype="float64")

def iterate(self, samples: list[str] = None) -> Iterator[dict]:
"""
Read phenotypes from a TSV line by line without storing anything
Parameters
----------
samples : list[str], optional
A subset of the samples from which to extract phenotypes
Defaults to loading phenotypes from all samples
Yields
------
Iterator[dict]
An iterator over each line in the file, where each line is encoded as a
dictionary containing each of the class properties
"""
with hook_compressed(self.fname, mode="rt") as phens:
phen_text = reader(phens, delimiter="\t")
for phen in phen_text:
if samples is None or phen[0] in samples:
try:
yield {"samples": phen[0], "data": float(phen[1])}
except:
raise AssertionError(
"The second column of the TSV file must numeric."
)

def standardize(self):
"""
Standardize phenotypes so they have a mean of 0 and a stdev of 1
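
A hedged sketch of using the new Phenotypes.iterate() to accumulate a running mean without loading the whole TSV; the file name is a placeholder and the class is assumed to be importable from haptools.data.

from haptools.data import Phenotypes

phens = Phenotypes("simple.tsv")
total, count = 0.0, 0
for line in phens.iterate():
    # each record is {"samples": <sample ID>, "data": <phenotype value as a float>}
    total += line["data"]
    count += 1
print(f"mean phenotype across {count} samples: {total / count:.3f}")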
35 changes: 35 additions & 0 deletions tests/test_data.py
@@ -78,6 +78,16 @@ def test_load_genotypes(caplog):
assert len(caplog.records) == 3 and caplog.records[2].levelname == "WARNING"


def test_load_genotypes_iterate(caplog):
expected = get_expected_genotypes().transpose((1, 0, 2))

# can we load the data from the VCF?
gts = Genotypes(DATADIR.joinpath("simple.vcf"))
for idx, line in enumerate(gts.iterate()):
np.testing.assert_allclose(line['data'], expected[idx])
assert line['samples'] == ("HG00096", "HG00097", "HG00099", "HG00100", "HG00101")


def test_load_genotypes_discard_multiallelic():
expected = get_expected_genotypes()

@@ -141,6 +151,18 @@ def test_load_phenotypes(caplog):
np.testing.assert_allclose(phens.data, expected)


def test_load_phenotypes_iterate(caplog):
# create a phenotype vector with shape: num_samples x 1
expected = np.array([1, 1, 2, 2, 0])
samples = ("HG00096", "HG00097", "HG00099", "HG00100", "HG00101")

# can we load the data from the phenotype file?
phens = Phenotypes(DATADIR.joinpath("simple.tsv"))
for idx, line in enumerate(phens.iterate()):
np.testing.assert_allclose(line['data'], expected[idx])
assert line['samples'] == samples[idx]


def test_load_phenotypes_subset():
# create a phenotype vector with shape: num_samples x 1
expected = np.array([1, 1, 2, 2, 0])
@@ -172,6 +194,19 @@ def test_load_covariates(caplog):
assert len(caplog.records) == 1 and caplog.records[0].levelname == "WARNING"


def test_load_covariates_iterate(caplog):
# create a covariate vector with shape: num_samples x num_covars
expected = np.array([(0, 4), (1, 20), (1, 33), (0, 15), (0, 78)])
samples = ("HG00096", "HG00097", "HG00099", "HG00100", "HG00101")

# can we load the data from the covariates file?
covars = Covariates(DATADIR.joinpath("covars.tsv"))
for idx, line in enumerate(covars.iterate()):
np.testing.assert_allclose(line['data'], expected[idx])
assert line['samples'] == samples[idx]
assert line['names'] == ("sex", "age")


def test_load_covariates_subset():
# create a covariate vector with shape: num_samples x num_covars
expected = np.array([(0, 4), (1, 20), (1, 33), (0, 15), (0, 78)])
