Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions kipoiseq/extractors/vcf_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Union

from pyfaidx import Sequence, complement
from kipoiseq.dataclasses import Interval
from cyvcf2 import VCF
from kipoiseq.dataclasses import Interval, Variant
from kipoiseq.extractors import (
BaseExtractor,
FastaStringExtractor,
Expand All @@ -15,7 +16,8 @@
__all__ = [
'VariantSeqExtractor',
'SingleVariantVCFSeqExtractor',
'SingleSeqVCFSeqExtractor'
'SingleSeqVCFSeqExtractor',
'SampleSeqExtractor'
]


Expand Down Expand Up @@ -320,3 +322,88 @@ def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
anchor=anchor,
fixed_len=fixed_len
)


class SampleSeqExtractor(VariantSeqExtractor):
def __init__(self, fasta_file, vcf_file):
"""Sequence extractor which can extract an alternate sequence for a
given interval and the variants corresponding to a given
sample and phase.

Args:
fasta_file: Path to the fasta file containing the reference
sequence (can be gzipped)
vcf_file: Path to the VCF file containing phased genotype information

"""
self.vcf = VCF(vcf_file)
self._sample_indices = dict(zip(self.vcf.samples,
range(len(self.vcf.samples))))

super().__init__(fasta_file)

def extract(self, interval, sample, phase, anchor,
fixed_len=True, **kwargs):
"""Extracts an alternate sequence for a given interval and the
variants corresponding to a given sample.

Args:
interval: `kipoiseq.dataclasses.Interval`, Region of
interest from which to query the sequence. 0-based.
sample: `str`, Sample from the VCF file for which variants should be
extracted.
phase: `0` or `1`, Phase for which sequence should be extracted
anchor: `int`, Absolution position w.r.t. the interval
start. (0-based). E.g. for an interval of `chr1:10-20`
the anchor of 10 denotes the point chr1:10 in the 0-based
coordinate system.
fixed_len: `bool`, If True, the return sequence will have the
same length as the `interval` (e.g. `interval.end -
interval.start`)
kwargs: Additional keyword arguments to pass to
`SampleSeqExtractor.extract`

Returns:
A single sequence (`str`) with all the variants applied.
"""
variants = []
if sample is not None:
if sample not in self.vcf.samples:
raise ValueError(f'Sample \'{sample}\' '
'not present in VCF file')

if phase not in (0, 1):
raise ValueError('phase argument must be in (0, 1)'
' if sample is not None')

# Interval is 0-based, cyvcf2 positions are 1-based: need to add 1
variants = self._get_sample_variants(
self.vcf(
f'{interval.chrom}:'
+ f'{interval.start + 1}-{interval.end + 1}'
),
sample,
phase
)

return super(SampleSeqExtractor, self).extract(
interval, variants, anchor, fixed_len, **kwargs)

def _get_sample_variants(self, variants, sample, phase):
"""Given a list of `cyvcf2.Variant`, returns all those present for a
given sample and phase and converts them to
`kipoiseq.dataclasses.Variant`

Args:
variants: List of `cyvcf2.Variant`, Variants of interest
sample: `str`, Sample for which to filter genotypes
phase: `0` or `1`, Phase for which to filter genotypes

Returns:
List of `kipoiseq.dataclasses.Variant`
"""
sample_index = self._sample_indices[sample]
return [
Variant.from_cyvcf(v) for v in variants
if v.genotypes[sample_index][phase]
]