Skip to content

Commit

Permalink
Basic prototype for plink
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Feb 7, 2024
1 parent 9f4d26e commit f45089e
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 6 deletions.
165 changes: 160 additions & 5 deletions sgkit/io/vcf/vcf_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import tqdm
import zarr

import bed_reader


# from sgkit.io.utils import FLOAT32_MISSING, str_is_int
from sgkit.io.utils import (
Expand All @@ -46,6 +48,7 @@
)


# TODO rename to wait_and_check_futures
def flush_futures(futures):
# Make sure previous futures have completed
for future in cf.as_completed(futures):
Expand Down Expand Up @@ -661,7 +664,7 @@ def convert(
)

global progress_counter
progress_counter = multiprocessing.Value("i", 0)
progress_counter = multiprocessing.Value("Q", 0)

# start update progress bar process
bar_thread = None
Expand Down Expand Up @@ -1111,7 +1114,6 @@ def encode_genotypes(self, pcvcf):
# FIXME set gt-mask
# gt_mask.buff[j: len(value) - 1] =


j += 1
if j == chunk_length:
flush_futures(futures)
Expand Down Expand Up @@ -1230,14 +1232,14 @@ def encode_id(self, pcvcf):
col = pcvcf.columns["ID"]
id_array = self.root["variant_id"]
id_mask_array = self.root["variant_id_mask"]
id_buff = np.full_like(id_array, '')
id_buff = np.full_like(id_array, "")
id_mask_buff = np.zeros_like(id_mask_array)

for j, value in enumerate(col.values()):
if value is not None:
id_buff[j] = value
else:
id_buff[j] = "." # TODO is this correct??
id_buff[j] = "." # TODO is this correct??
id_mask_buff[j] = True

id_array[:] = id_buff
Expand All @@ -1246,7 +1248,6 @@ def encode_id(self, pcvcf):
with progress_counter.get_lock():
progress_counter.value += col.vcf_field.summary.uncompressed_size


@staticmethod
def convert(
pcvcf, path, conversion_spec, *, worker_processes=1, show_progress=False
Expand Down Expand Up @@ -1378,3 +1379,157 @@ def convert_vcf(
worker_processes=worker_processes,
show_progress=show_progress,
)


def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_variant):
bed = bed_reader.open_bed(bed_path, num_threads=1)

store = zarr.DirectoryStore(zarr_path)
root = zarr.group(store=store)
gt = BufferedArray(root["call_genotype"])
gt_mask = BufferedArray(root["call_genotype_mask"])
gt_phased = BufferedArray(root["call_genotype_phased"])
chunk_length = gt.array.chunks[0]
assert start_variant % chunk_length == 0

buffered_arrays = [gt, gt_phased, gt_mask]

with cf.ThreadPoolExecutor(max_workers=8) as executor:
futures = []

start = start_variant
while start < end_variant:
stop = min(start + chunk_length, end_variant)
bed_chunk = bed.read(index=slice(start, stop), dtype="int8").T
# Note could do this without iterating over rows, but it's a bit
# simpler and the bottleneck is in the encoding step anyway. It's
# also nice to have updates on the progress monitor.
for j, values in enumerate(bed_chunk):
dest = gt.buff[j]
dest[values == -127] = -1
dest[values == 2] = 1
dest[values == 1, 0] = 1
gt_phased.buff[j] = False
gt_mask.buff[j] = dest == -1
with progress_counter.get_lock():
progress_counter.value += 1

assert j <= chunk_length
flush_futures(futures)
for ba in buffered_arrays:
futures.extend(
async_flush_array(executor, ba.buff[:j], ba.array, start)
)
ba.swap_buffers()
start = stop
flush_futures(futures)


def convert_plink(
bed_path,
zarr_path,
*,
show_progress,
worker_processes=1,
chunk_length=None,
chunk_width=None,
):
bed = bed_reader.open_bed(bed_path, num_threads=1)
n = bed.iid_count
m = bed.sid_count
del bed

# FIXME
if chunk_width is None:
chunk_width = 1000
if chunk_length is None:
chunk_length = 10_000

store = zarr.DirectoryStore(zarr_path)
root = zarr.group(store=store, overwrite=True)

ploidy = 2
shape = [m, n]
chunks = [chunk_length, chunk_width]
dimensions = ["variants", "samples"]

a = root.empty(
"call_genotype_phased",
dtype="bool",
shape=list(shape),
chunks=list(chunks),
compressor=default_compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)

shape += [ploidy]
dimensions += ["ploidy"]
a = root.empty(
"call_genotype",
dtype="i8",
shape=list(shape),
chunks=list(chunks),
compressor=default_compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)

a = root.empty(
"call_genotype_mask",
dtype="bool",
shape=list(shape),
chunks=list(chunks),
compressor=default_compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = list(dimensions)

global progress_counter
progress_counter = multiprocessing.Value("Q", 0)

# start update progress bar process
bar_thread = None
if show_progress:
bar_thread = threading.Thread(
target=update_bar,
args=(progress_counter, m, "Write", "vars"),
name="progress",
daemon=True,
)
bar_thread.start()

num_chunks = m // chunk_length
worker_processes = min(worker_processes, num_chunks)
if num_chunks == 1 or worker_processes == 1:
partitions = [(0, m)]
else:
# Generate num_workers partitions
# TODO finer grained might be better.
partitions = []
chunk_boundaries = [
p[0] for p in np.array_split(np.arange(num_chunks), worker_processes)
]
for j in range(len(chunk_boundaries) - 1):
start = chunk_boundaries[j] * chunk_length
end = chunk_boundaries[j + 1] * chunk_length
end = min(end, m)
partitions.append((start, end))
last_stop = partitions[-1][-1]
if last_stop != m:
partitions.append((last_stop, m))
# print(partitions)

with cf.ProcessPoolExecutor(
max_workers=worker_processes,
initializer=init_workers,
initargs=(progress_counter,),
) as executor:
futures = [
executor.submit(
encode_bed_partition_genotypes, bed_path, zarr_path, start, end
)
for start, end in partitions
]
flush_futures(futures)
# print("progress counter = ", m, progress_counter.value)
assert progress_counter.value == m

# print(root["call_genotype"][:])
19 changes: 18 additions & 1 deletion vcf2zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def genspec(columnarised):
json.dump(spec.asdict(), stream, indent=4)



@click.command
@click.argument("columnarised", type=click.Path())
@click.argument("zarr_path", type=click.Path())
Expand Down Expand Up @@ -86,6 +85,23 @@ def convert(vcfs, out_path):
cnv.convert_vcf(vcfs, out_path, show_progress=True)


@click.command
@click.argument("plink", type=click.Path())
@click.argument("out_path", type=click.Path())
@click.option("-p", "--worker-processes", type=int, default=1)
@click.option("--chunk-width", type=int, default=None)
@click.option("--chunk-length", type=int, default=None)
def convert_plink(plink, out_path, worker_processes, chunk_width, chunk_length):
cnv.convert_plink(
plink,
out_path,
show_progress=True,
worker_processes=worker_processes,
chunk_width=chunk_width,
chunk_length=chunk_length,
)


@click.group()
def cli():
pass
Expand All @@ -96,6 +112,7 @@ def cli():
cli.add_command(genspec)
cli.add_command(to_zarr)
cli.add_command(convert)
cli.add_command(convert_plink)

if __name__ == "__main__":
cli()

0 comments on commit f45089e

Please sign in to comment.