from cyvcf2 import VCF
from kipoiseq.dataclasses import Interval, Variant


class SampleSeqExtractor(VariantSeqExtractor):
    """Extract the sequence of one sample's haplotype from a FASTA + VCF.

    Variants that the requested sample carries on the requested phase are
    looked up in the VCF and applied to the reference sequence by the parent
    class `VariantSeqExtractor` (listed in this module's `__all__`).
    """

    def __init__(self, fasta_file, vcf_file):
        """Sequence extractor which can extract an alternate sequence for a
        given interval and the variants corresponding to a given
        sample and phase.

        Args:
          fasta_file: Path to the fasta file containing the reference
            sequence (can be gzipped)
          vcf_file: Path to the VCF file containing phased genotype
            information
        """
        self.vcf = VCF(vcf_file)
        # Sample name -> column index into `cyvcf2.Variant.genotypes`.
        self._sample_indices = {
            name: index for index, name in enumerate(self.vcf.samples)
        }

        super().__init__(fasta_file)

    def extract(self, interval, sample, phase, anchor,
                fixed_len=True, **kwargs):
        """Extracts an alternate sequence for a given interval and the
        variants corresponding to a given sample.

        Args:
          interval: `kipoiseq.dataclasses.Interval`, Region of
            interest from which to query the sequence. 0-based.
          sample: `str` or `None`, Sample from the VCF file for which
            variants should be extracted. If `None`, no variants are
            looked up and an empty variant list is passed on.
          phase: `0` or `1`, Phase for which sequence should be extracted.
            Only validated when `sample` is not `None`.
          anchor: `int`, Absolute position w.r.t. the interval
            start. (0-based). E.g. for an interval of `chr1:10-20`
            the anchor of 10 denotes the point chr1:10 in the 0-based
            coordinate system.
          fixed_len: `bool`, If True, the return sequence will have the
            same length as the `interval` (e.g. `interval.end -
            interval.start`)
          kwargs: Additional keyword arguments to pass to
            `VariantSeqExtractor.extract`

        Returns:
          A single sequence (`str`) with all the variants applied.

        Raises:
          ValueError: If `sample` is not present in the VCF file, or if
            `phase` is not 0 or 1 while `sample` is given.
        """
        variants = []
        if sample is not None:
            # Dict lookup instead of scanning the `self.vcf.samples` list.
            if sample not in self._sample_indices:
                raise ValueError(f"Sample '{sample}' "
                                 'not present in VCF file')

            if phase not in (0, 1):
                raise ValueError('phase argument must be in (0, 1)'
                                 ' if sample is not None')

            # The interval is 0-based and has length `end - start`
            # (half-open), while cyvcf2 region strings are 1-based and
            # inclusive: shift the start by one and keep the end as is.
            region = (f'{interval.chrom}:'
                      f'{interval.start + 1}-{interval.end}')
            variants = self._get_sample_variants(
                self.vcf(region), sample, phase)

        return super().extract(
            interval, variants, anchor, fixed_len=fixed_len, **kwargs)

    @staticmethod
    def _has_alt_allele(genotype, phase):
        """Tell whether a cyvcf2 genotype carries an ALT allele on `phase`.

        Args:
          genotype: One entry of `cyvcf2.Variant.genotypes`, i.e. a list of
            allele indices followed by a trailing phased flag.
          phase: `int`, Index of the haplotype to inspect.

        Returns:
          `bool`, True only for a called, non-reference allele.
        """
        # Drop the trailing phased flag so it is never read as an allele
        # (this matters for haploid calls, which have a single allele).
        alleles = genotype[:-1]
        # 0 is the reference allele and -1 a missing call; neither means
        # the sample carries the variant.
        return phase < len(alleles) and alleles[phase] > 0

    def _get_sample_variants(self, variants, sample, phase):
        """Given an iterable of `cyvcf2.Variant`, returns all those present
        for a given sample and phase and converts them to
        `kipoiseq.dataclasses.Variant`

        Args:
          variants: Iterable of `cyvcf2.Variant`, Variants of interest
          sample: `str`, Sample for which to filter genotypes
          phase: `0` or `1`, Phase for which to filter genotypes

        Returns:
          List of `kipoiseq.dataclasses.Variant`
        """
        sample_index = self._sample_indices[sample]
        return [
            Variant.from_cyvcf(v) for v in variants
            if self._has_alt_allele(v.genotypes[sample_index], phase)
        ]