Coverage for src/pangwas/pangwas.py: 97%
3607 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-25 21:02 +0000
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-25 21:02 +0000
1#!/usr/bin/env python3
3# Base python packages
4from collections import OrderedDict, Counter
5import copy
6import itertools
7import logging
8import math
9import os
10import re
11import shutil
12import subprocess
13import sys
14import textwrap
16from tqdm import tqdm
18NUCLEOTIDES = ["A", "C", "G", "T"]
20# Default arguments for programs
21CLUSTER_ARGS = "-k 13 --min-seq-id 0.90 -c 0.90 --cluster-mode 2 --max-seqs 300"
22DEFRAG_ARGS = "-k 13 --min-seq-id 0.90 -c 0.90 --cov-mode 1"
23ALIGN_ARGS = "--adjustdirection --localpair --maxiterate 1000 --addfragments"
24TREE_ARGS = "-safe -m MFP"
25GWAS_ARGS = "--lmm"
27LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
28logging.basicConfig(level=LOGLEVEL, stream=sys.stdout, format='%(asctime)s %(funcName)20s %(levelname)8s: %(message)s')
30# Notify users when they use an algorithm from ppanggolin
31PPANGGOLIN_NOTICE = """
32 The defrag algorithm is adapted from PPanGGOLiN (v2.2.0) which is
33 distributed under the CeCILL FREE SOFTWARE LICENSE AGREEMENT LABGeM.
34 - Please cite: Gautreau G et al. (2020) PPanGGOLiN: Depicting microbial diversity via a partitioned pangenome graph.
35 PLOS Computational Biology 16(3): e1007732. https://doi.org/10.1371/journal.pcbi.1007732
36 - PPanGGOLiN license: https://github.com/labgem/PPanGGOLiN/blob/2.2.0/LICENSE.txt
37 - Defrag algorithm source: https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317
38"""
40pangwas_description = "pangenome wide association study (panGWAS) pipeline."
41annotate_description = "Annotate genomic assemblies with bakta."
42extract_description = "Extract sequences and annotations from GFF files."
43collect_description = "Collect extracted sequences from multiple samples into one file."
44cluster_description = "Cluster nucleotide sequences with mmseqs."
45defrag_description = "Defrag clusters by associating fragments with their parent cluster."
46summarize_description = "Summarize clusters according to their annotations."
47align_description = "Align clusters using mafft and create a pangenome alignment."
48snps_description = "Extract SNPs from a pangenome alignment."
49structural_description = "Extract structural variants from cluster alignments."
50presence_absence_description = "Extract presence absence of clusters."
51combine_description = "Combine variants from multiple Rtab files."
52table_to_rtab_description = "Convert a TSV/CSV table to an Rtab file based on regex filters."
53tree_description = "Estimate a maximum-likelihood tree with IQ-TREE."
54root_tree_description = "Root tree on outgroup taxa."
55binarize_description = "Convert a categorical column to multiple binary (0/1) columns."
56vcf_to_rtab_description = "Convert a VCF file to an Rtab file."
57gwas_description = "Run genome-wide association study (GWAS) tests with pyseer."
58heatmap_description = "Plot a heatmap of variants alongside a tree."
59manhattan_description = "Plot the distribution of variant p-values across the genome."
61def get_options(args:str=None):
62 import argparse
64 sys_argv_original = sys.argv
66 if args != None:
67 sys.argv = args.split(" ")
69 description = textwrap.dedent(
70 f"""\
71 {pangwas_description}
73 ANNOTATE
74 annotate: {annotate_description}
76 CLUSTER
77 extract: {extract_description}
78 collect: {collect_description}
79 cluster: {cluster_description}
80 defrag: {defrag_description}
81 summarize: {summarize_description}
83 ALIGN
84 align: {align_description}
86 VARIANTS
87 structural: {structural_description}
88 snps: {snps_description}
89 presence_absence: {presence_absence_description}
91 TREE
92 tree: {tree_description}
94 GWAS
95 gwas: {gwas_description}
97 PLOT
98 manhattan: {manhattan_description}
99 heatmap: {heatmap_description}
101 UTILITY
102 root_tree: {root_tree_description}
103 binarize: {binarize_description}
104 table_to_rtab: {table_to_rtab_description}
105 vcf_to_rtab: {vcf_to_rtab_description}
106 """)
107 parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter)
108 parser.add_argument('--version', help='Display program version.', action="store_true")
109 subbcommands = parser.add_subparsers(dest="subcommand")
111 # -------------------------------------------------------------------------
112 # Annotate
114 description = textwrap.dedent(
115 f"""\
116 {annotate_description}
118 Takes as input a FASTA file of genomic assemblies. Outputs a GFF file
119 of annotations, among many other formats from bakta.
121 All additional arguments with be passed to the `bakta` CLI.
123 Examples:
124 > pangwas annotate --fasta sample1.fasta --db database/bakta
125 > pangwas annotate --fasta sample1.fasta --db database/bakta --sample sample1 --threads 2 --genus Streptococcus
126 """)
128 annotate_parser = subbcommands.add_parser('annotate', description = description, formatter_class=argparse.RawTextHelpFormatter)
130 annotate_req_parser = annotate_parser.add_argument_group("required arguments")
131 annotate_req_parser.add_argument("--fasta", required=True, help='Input FASTA sequences.')
132 annotate_req_parser.add_argument("--db", required=True, help='bakta database directory.')
134 annotate_out_opt_parser = annotate_parser.add_argument_group("optional output arguments")
135 annotate_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".")
136 annotate_out_opt_parser.add_argument('--prefix', help='Output file prefix. If not provided, will be parsed from the fasta file name.')
137 annotate_out_opt_parser.add_argument("--tmp", help='Temporary directory.')
139 annotate_opt_parser = annotate_parser.add_argument_group("optional arguments")
140 annotate_opt_parser.add_argument("--sample", help='Sample identifier. If not provided, will be parsed from the fasta file name.')
141 annotate_opt_parser.add_argument("--threads", help='CPU threads for bakta. (default: 1)', default=1)
143 # -------------------------------------------------------------------------
144 # Extract Sequences
146 description = textwrap.dedent(
147 f"""\
148 {extract_description}
150 Takes as input a GFF annotations file. If sequences are not included, a FASTA
151 of genomic contigs must also be provided. Both annotated and unannotated regions
152 will be extracted. Outputs a TSV table of extracted sequence regions.
154 Examples:
155 > pangwas extract --gff sample1.gff3
156 > pangwas extract --gff sample1.gff3 --fasta sample1.fasta --min-len 10
157 """)
159 extract_parser = subbcommands.add_parser('extract', description = description, formatter_class=argparse.RawTextHelpFormatter)
161 extract_req_parser = extract_parser.add_argument_group("required arguments")
162 extract_req_parser.add_argument('--gff', required=True, help='Input GFF annotations.')
164 extract_out_opt_parser = extract_parser.add_argument_group("optional output arguments")
165 extract_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".")
166 extract_out_opt_parser.add_argument('--prefix', help='Output file prefix. If not provided, will be parsed from the gff file name.')
168 extract_opt_parser = extract_parser.add_argument_group("optional arguments")
169 extract_opt_parser.add_argument('--fasta', help='Input FASTA sequences, if not provided at the end of the GFF.')
170 extract_opt_parser.add_argument('--max-len', help='Maximum length of sequences to extract (default: 100000).', type=int, default=100000)
171 extract_opt_parser.add_argument('--min-len', help='Minimum length of sequences to extract (default: 20).', type=int, default=20)
172 extract_opt_parser.add_argument('--sample', help='Sample identifier to use. If not provided, is parsed from the gff file name.')
173 extract_opt_parser.add_argument('--regex', help='Only extract gff lines that match this regular expression.')
175 # -------------------------------------------------------------------------
176 # Collect Sequences
178 description = textwrap.dedent(
179 f"""\
180 {collect_description}
182 Takes as input multiple TSV files from extract, which can be supplied
183 as either space separate paths, or a text file containing paths.
184 Duplicate sequence IDs will be identified and given the suffix '.#'.
185 Outputs concatenated FASTA and TSV files.
187 Examples:
188 > pangwas collect --tsv sample1.tsv sample2.tsv sample3.tsv sample4.tsv
189 > pangwas collect --tsv-paths tsv_paths.txt
190 """)
192 collect_parser = subbcommands.add_parser('collect', description = description, formatter_class=argparse.RawTextHelpFormatter)
194 collect_req_parser = collect_parser.add_argument_group("required arguments (mutually-exclusive)")
195 collect_input = collect_req_parser.add_mutually_exclusive_group(required=True)
196 collect_input.add_argument('--tsv', help='TSV files from the extract subcommand.', nargs='+')
197 collect_input.add_argument('--tsv-paths', help='TXT file containing paths to TSV files.')
199 collect_out_opt_parser = collect_parser.add_argument_group("optional output arguments")
200 collect_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".")
201 collect_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
203 # -------------------------------------------------------------------------
204 # Cluster Sequences (mmseqs)
206 description = textwrap.dedent(
207 f"""\
208 {cluster_description}
210 Takes as input a FASTA file of sequences for clustering. Calls MMSeqs2
211 to cluster sequences and identify a representative sequence. Outputs a
212 TSV table of sequence clusters and a FASTA of representative sequences.
214 Any additional arguments will be passed to `mmseqs cluster`. If no additional
215 arguments are used, the following default args will apply:
216 {CLUSTER_ARGS}
218 Examples:
219 > pangwas cluster --fasta collect.fasta
220 > pangwas cluster --fasta collect.fasta --threads 2 -k 13 --min-seq-id 0.90 -c 0.90
221 """)
223 cluster_parser = subbcommands.add_parser('cluster', description = description, formatter_class=argparse.RawTextHelpFormatter)
225 cluster_req_parser = cluster_parser.add_argument_group("required arguments")
226 cluster_req_parser.add_argument('-f', '--fasta', required=True, help='FASTA file of input sequences to cluster.')
228 cluster_out_opt_parser = cluster_parser.add_argument_group("optional output arguments")
229 cluster_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".")
230 cluster_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
231 cluster_out_opt_parser.add_argument('--tmp', help='Tmp directory (default: tmp).', default="tmp")
233 cluster_opt_parser = cluster_parser.add_argument_group("optional arguments")
234 cluster_opt_parser.add_argument('--memory', help='Memory limit for mmseqs split. (default: 1G)', default="1G")
235 cluster_opt_parser.add_argument('--no-clean', help="Don't clean up intermediate files.", action="store_false", dest="clean")
236 cluster_opt_parser.add_argument('--threads', help='CPU threads for mmseqs. (default: 1)', default=1)
238 # -------------------------------------------------------------------------
239 # Defrag clusters
241 description = textwrap.dedent(
242 f"""\
243 {defrag_description}
245 Takes as input the TSV clusters and FASTA representatives from cluster.
246 Outputs a new cluster table and representative sequences fasta.
248 {PPANGGOLIN_NOTICE}
250 Any additional arguments will be passed to `mmseqs search`. If no additional
251 arguments are used, the following default args will apply:
252 {DEFRAG_ARGS}
254 Examples:
255 > pangwas defrag --clusters clusters.tsv --representative representative.fasta --prefix defrag
256 > pangwas defrag --clusters clusters.tsv --representative representative.fasta --prefix defrag --threads 2 -k 13 --min-seq-id 0.90 -c 0.90
257 """)
259 defrag_parser = subbcommands.add_parser('defrag', description=description, formatter_class=argparse.RawTextHelpFormatter)
260 defrag_req_parser = defrag_parser.add_argument_group("required arguments")
261 defrag_req_parser.add_argument('--clusters', required=True, help='TSV file of clusters from mmseqs.')
262 defrag_req_parser.add_argument('--representative', required=True, help='FASTA file of cluster representative sequences.')
264 defrag_out_opt_parser = defrag_parser.add_argument_group("optional output arguments")
265 defrag_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".")
266 defrag_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
267 defrag_out_opt_parser.add_argument('--tmp', help='Tmp directory (default: tmp).', default="tmp")
269 defrag_opt_parser = defrag_parser.add_argument_group("optional arguments")
270 defrag_opt_parser.add_argument('--memory', help='Memory limit for mmseqs split. (default: 2G)', default="2G")
271 defrag_opt_parser.add_argument('--no-clean', help="Don't clean up intermediate files.", action="store_false", dest="clean")
272 defrag_opt_parser.add_argument('--threads', help='CPU threads for mmseqs. (default: 1)', default=1)
274 # -------------------------------------------------------------------------
275 # Summarize clusters
277 description = textwrap.dedent(
278 f"""\
279 {summarize_description}
281 Takes as input the TSV table from collect, and the clusters table from
282 either cluster or defrag. Outputs a TSV table of summarized clusters
283 with their annotations.
285 Examples:
286 > pangwas summarize --clusters defrag.clusters.tsv --regions regions.tsv
287 """)
289 summarize_parser = subbcommands.add_parser('summarize', description=description, formatter_class=argparse.RawTextHelpFormatter)
290 summarize_req_parser = summarize_parser.add_argument_group("required arguments")
291 summarize_req_parser.add_argument('--clusters', required=True, help='TSV file of clusters from cluster or defrag.')
292 summarize_req_parser.add_argument('--regions', required=True, help='TSV file of sequence regions from collect.')
294 summarize_opt_parser = summarize_parser.add_argument_group("optional arguments")
295 summarize_opt_parser.add_argument("--max-product-len", help='Truncate the product description to this length if used as an identifie. (default: 50)', type=int, default=50)
296 summarize_opt_parser.add_argument("--min-samples", help='Cluster must be observed in at least this many samples to be summarized.', type=int, default=1)
297 summarize_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
298 summarize_opt_parser.add_argument('--prefix', help='Prefix for output files.')
299 summarize_opt_parser.add_argument('--threshold', help='Required this proportion of samples to have annotations in agreement. (default: 0.5)', type=float, default=0.5)
301 # -------------------------------------------------------------------------
302 # Align clusters
304 description = textwrap.dedent(
305 f"""\
306 {align_description}
308 Takes as input the clusters from summarize and the sequence regions
309 from collect. Outputs multiple sequence alignments per cluster
310 as well as a pangenome alignment of concatenated clusters.\n
312 Any additional arguments will be passed to `mafft`. If no additional
313 arguments are used, the following default args will apply:
314 {ALIGN_ARGS}
316 Examples:
317 > pangwas align --clusters clusters.tsv --regions regions.tsv
318 > pangwas align --clusters clusters.tsv --regions regions.tsv --threads 2 --exclude-singletons --localpair --maxiterate 100
319 """)
321 align_parser = subbcommands.add_parser('align', description=description, formatter_class=argparse.RawTextHelpFormatter)
323 align_req_parser = align_parser.add_argument_group("required arguments")
324 align_req_parser.add_argument('--clusters', help='TSV of clusters from summarize.', required=True)
325 align_req_parser.add_argument('--regions', help='TSV of sequence regions from collect.', required=True)
327 align_out_opt_parser = align_parser.add_argument_group("optional output arguments")
328 align_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
329 align_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
331 align_opt_parser = align_parser.add_argument_group("optional arguments")
332 align_opt_parser.add_argument('--threads', help='CPU threads for running mafft in parallel. (default: 1)', type=int, default=1)
333 align_opt_parser.add_argument('--exclude-singletons', help='Exclude clusters found in only 1 sample.', action="store_true")
335 # -------------------------------------------------------------------------
336 # Variants: Structural
338 description = textwrap.dedent(
339 f"""\
340 {structural_description}
342 Takes as input the summarized clusters and their individual alignments.
343 Outputs an Rtab file of structural variants.
345 Examples:
346 > pangwas structural --clusters clusters.tsv --alignments alignments
347 > pangwas structural --clusters clusters.tsv --alignments alignments --min-len 100 --min-indel-len 10
348 """)
350 structural_parser = subbcommands.add_parser('structural', description=description, formatter_class=argparse.RawTextHelpFormatter)
352 structural_req_parser = structural_parser.add_argument_group("required arguments")
353 structural_req_parser.add_argument('--clusters', required=True, help='TSV of clusters from summarize.')
354 structural_req_parser.add_argument('--alignments', required=True, help='Directory of cluster alignments (not consensus alignments!).')
356 structural_out_opt_parser = structural_parser.add_argument_group("optional output arguments")
357 structural_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
358 structural_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
360 structural_opt_parser = structural_parser.add_argument_group("optional arguments")
361 structural_opt_parser.add_argument('--min-len', help='Minimum variant length. (default: 10)', type=int, default=10)
362 structural_opt_parser.add_argument('--min-indel-len', help='Minimum indel length. (default: 3)', type=int, default=3)
365 # -------------------------------------------------------------------------
366 # Variants: SNPs
368 description = textwrap.dedent(
369 f"""\
370 {snps_description}
372 Takes as input the pangenome alignment fasta, bed, and consensus file from align.
373 Outputs an Rtab file of SNPs.
375 Examples:
376 > pangwas snps --alignment pangenome.aln --bed pangenome.bed --consensus pangenome.consensus.fasta
377 > pangwas snps --alignment pangenome.aln --bed pangenome.bed --consensus pangenome.consensus.fasta --structural structural.Rtab --core 0.90 --indel-window 3 --snp-window 10
378 """)
380 snps_parser = subbcommands.add_parser('snps', description=description, formatter_class=argparse.RawTextHelpFormatter)
382 snps_req_parser = snps_parser.add_argument_group("required arguments")
383 snps_req_parser.add_argument('--alignment', required=True, help='Fasta sequences alignment.')
384 snps_req_parser.add_argument('--bed', required=True, help='Bed file of coordinates.')
385 snps_req_parser.add_argument('--consensus', required=True, help='Fasta consensus/representative sequence.')
387 snps_rec_parser = snps_parser.add_argument_group("optional but recommended arguments")
388 snps_rec_parser.add_argument('--structural', help='Rtab from the structural subcommand, used to avoid treating terminal ends as indels.')
390 snps_out_opt_parser = snps_parser.add_argument_group("optional output arguments")
391 snps_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
392 snps_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
394 snps_opt_parser = snps_parser.add_argument_group("optional arguments")
395 snps_opt_parser.add_argument('--core', help='Core threshold for calling core SNPs. (default: 0.95)', type=float, default=0.95)
396 snps_opt_parser.add_argument('--indel-window', help='Exclude SNPs that are within this proximity to indels. (default: 0)', type=int, default=0)
397 snps_opt_parser.add_argument('--snp-window', help='Exclude SNPs that are within this proximity to another SNP. (default: 0)', type=int, default=0)
399 # -------------------------------------------------------------------------
400 # Variants: Presence Absence
402 description = textwrap.dedent(
403 f"""\
404 {presence_absence_description}
406 Takes as input a TSV of summarized clusters from summarize.
407 Outputs an Rtab file of cluster presence/absence.
409 Examples:
410 > pangwas presence_absence --clusters clusters.tsv
411 """)
413 pres_parser = subbcommands.add_parser('presence_absence', description=description, formatter_class=argparse.RawTextHelpFormatter)
414 pres_req_parser = pres_parser.add_argument_group("required arguments")
415 pres_req_parser.add_argument('--clusters', help='TSV of clusters from summarize.', required=True)
417 pres_out_opt_parser = pres_parser.add_argument_group("optional output arguments")
418 pres_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
419 pres_out_opt_parser.add_argument('--prefix', help='Prefix for output files.')
421 # -------------------------------------------------------------------------
422 # Variants: Combine
424 description = textwrap.dedent(
425 f"""\
426 {combine_description}
428 Takes as input a list of file paths to Rtab files. Outputs an Rtab file with
429 the variants concatenated, ensuring consistent ordering of the sample columns.
431 Examples:
432 > pangwas combine --rtab snps.Rtab presence_absence.Rtab
433 > pangwas combine --rtab snps.Rtab structural.Rtab presence_absence.Rtab
434 """)
435 combine_parser = subbcommands.add_parser('combine', description=description, formatter_class=argparse.RawTextHelpFormatter)
437 combine_req_parser = combine_parser.add_argument_group("required arguments")
438 combine_req_parser.add_argument('--rtab', required=True, help="Rtab variants files.", nargs='+')
440 combine_opt_parser = combine_parser.add_argument_group("optional arguments")
441 combine_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
442 combine_opt_parser.add_argument('--prefix', help='Prefix for output files.')
444 # -------------------------------------------------------------------------
445 # Variants: Table to Rtab
447 description = textwrap.dedent(
448 f"""\
449 {table_to_rtab_description}
451 Takes as input a TSV/CSV table to convert, and a TSV/CSV of regex filters.
452 The filter table should have the header: column, regex, name. Where column
453 is the 'column' to search, 'regex' is the regular expression pattern, and
454 'name' is how the output variant should be named in the Rtab.
456 An example `filter.tsv` might look like this:
458 column regex name
459 assembly .*sample2.* sample2
460 lineage .*2.* lineage_2
462 Where the goal is to filter the assembly and lineage columns for particular values.
464 Examples:
465 > pangwas table_to_rtab --table samplesheet.csv --filter filter.tsv
466 """
467 )
469 table_to_rtab_parser = subbcommands.add_parser('table_to_rtab', description=description, formatter_class=argparse.RawTextHelpFormatter)
470 table_to_rtab_req_parser = table_to_rtab_parser.add_argument_group("required arguments")
471 table_to_rtab_req_parser.add_argument('--table', required=True, help='TSV or CSV table.')
472 table_to_rtab_req_parser.add_argument('--filter', required=True, help='TSV or CSV filter table.')
474 table_to_rtab_opt_parser = table_to_rtab_parser.add_argument_group("optional arguments")
475 table_to_rtab_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
476 table_to_rtab_opt_parser.add_argument('--prefix', help='Prefix for output files.')
478 # -------------------------------------------------------------------------
479 # Variants: VCF to Rtab
481 description = textwrap.dedent(
482 f"""\
483 {vcf_to_rtab_description}
485 Takes as input a VCF file to convert to a SNPs Rtab file.
487 Examples:
488 > pangwas vcf_to_rtab --vcf snps.vcf
489 """
490 )
492 vcf_to_rtab_parser = subbcommands.add_parser('vcf_to_rtab', description=description, formatter_class=argparse.RawTextHelpFormatter)
493 vcf_to_rtab_req_parser = vcf_to_rtab_parser.add_argument_group("required arguments")
494 vcf_to_rtab_req_parser.add_argument('--vcf', required=True, help='VCF file.')
496 vcf_to_rtab_opt_parser = vcf_to_rtab_parser.add_argument_group("optional arguments")
497 vcf_to_rtab_opt_parser.add_argument("--bed", help='BED file with names by coordinates.')
498 vcf_to_rtab_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
499 vcf_to_rtab_opt_parser.add_argument('--prefix', help='Prefix for output files.')
501 # -------------------------------------------------------------------------
502 # Tree
504 description = textwrap.dedent(
505 f"""\
506 {tree_description}
508 Takes as input a multiple sequence alignment in FASTA format. If a SNP
509 alignment is provided, an optional text file of constant sites can be
510 included for correction. Outputs a maximum-likelihood tree, as well as
511 additional rooted trees if an outgroup is specified in the iqtree args.
513 Any additional arguments will be passed to `iqtree`. If no additional
514 arguments are used, the following default args will apply:
515 {TREE_ARGS}
517 Examples:
518 > pangwas tree --alignment snps.core.fasta --constant-sites snps.constant_sites.txt
519 > pangwas tree --alignment pangenome.aln --threads 4 -o sample1 --ufboot 1000
520 """
521 )
522 tree_parser = subbcommands.add_parser('tree', description=description, formatter_class=argparse.RawTextHelpFormatter)
524 tree_req_parser = tree_parser.add_argument_group("required arguments")
525 tree_req_parser.add_argument('--alignment', required=True, help='Multiple sequence alignment.')
527 tree_opt_parser = tree_parser.add_argument_group("optional arguments")
528 tree_opt_parser.add_argument('--constant-sites', help='Text file containing constant sites correction for SNP alignment.')
529 tree_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
530 tree_opt_parser.add_argument('--prefix', help='Prefix for output files.')
531 tree_opt_parser.add_argument('--threads', help='CPU threads for IQTREE. (default: 1)', default=1)
533 # -------------------------------------------------------------------------
534 # Root tree
536 description = textwrap.dedent(
537 f"""\
538 {root_tree_description}
540 Takes as input a path to a phylogenetic tree and outgroup taxa. This is
541 a utility script that is meant to fix IQ-TREE's creation of a multifurcated
542 root node. It will position the root node along the midpoint of the branch
543 between the outgroup taxa and all other samples. If no outgroup is selected,
544 the tree will be rooted using the first taxa. Outputs a new tree in the specified
545 tree format.
547 Note: This functionality is already included in the tree subcommand.
549 Examples:
550 > pangwas root_tree --tree tree.treefile
551 > pangwas root_tree --tree tree.treefile --outgroup sample1
552 > pangwas root_tree --tree tree.treefile --outgroup sample1,sample4
553 > pangwas root_tree --tree tree.nex --outgroup sample1 --tree-format nexus
554 """
555 )
556 root_tree_parser = subbcommands.add_parser('root_tree', description=description, formatter_class=argparse.RawTextHelpFormatter)
558 root_tree_req_parser = root_tree_parser.add_argument_group("required arguments")
559 root_tree_req_parser.add_argument('--tree', help='Path to phylogenetic tree.')
561 root_tree_opt_parser = root_tree_parser.add_argument_group("optional arguments")
562 root_tree_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
563 root_tree_opt_parser.add_argument("--outgroup", help='Outgroup taxa as CSV string. If not specified, roots on first taxon.')
564 root_tree_opt_parser.add_argument('--prefix', help='Prefix for output files.')
565 root_tree_opt_parser.add_argument('--tree-format', help='Tree format. (default: newick)', type=str, default="newick")
567 # -------------------------------------------------------------------------
568 # GWAS: Binarize
570 description = textwrap.dedent(
571 f"""\
572 {binarize_description}
574 Takes as input a table (TSV or CSV) and the name of a column to binarize.
575 Outputs a new table with separate columns for each categorical value.
577 Any additional arguments will be passed to `pyseer`.
579 Examples:
580 pangwas binarize --table samplesheet.csv --column lineage --output lineage.binarize.csv
581 pangwas binarize --table samplesheet.tsv --column outcome --output outcome.binarize.tsv
582 """)
584 binarize_parser = subbcommands.add_parser('binarize', description=description, formatter_class=argparse.RawTextHelpFormatter)
585 binarize_req_parser = binarize_parser.add_argument_group("required arguments")
586 binarize_req_parser.add_argument('--table', required=True, help='TSV or CSV table.')
587 binarize_req_parser.add_argument('--column', required=True, help='Column name to binarize.')
589 binarize_opt_parser = binarize_parser.add_argument_group("optional arguments")
590 binarize_opt_parser.add_argument('--column-prefix', help='Prefix to add to column names.')
591 binarize_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
592 binarize_opt_parser.add_argument("--output-delim", help='Output delimiter. (default: \t )', default="\t")
593 binarize_opt_parser.add_argument('--prefix', help='Prefix for output files')
594 binarize_opt_parser.add_argument('--transpose', help='Tranpose output table.', action='store_true')
596 # -------------------------------------------------------------------------
597 # GWAS: pyseer
599 description = textwrap.dedent(
600 f"""\
601 {gwas_description}
603 Takes as input the TSV file of summarized clusters from summarize, an Rtab file of variants,
604 a TSV/CSV table of phenotypes, and a column name representing the trait of interest.
605 Outputs tables of locus effects, and optionally lineage effects (bugwas) if specified.
607 Any additional arguments will be passed to `pyseer`. If no additional
608 arguments are used, the following default args will apply:
609 {GWAS_ARGS}
611 Examples:
612 > pangwas gwas --variants combine.Rtab --table samplesheet.csv --column lineage --no-distances
613 > pangwas gwas --variants combine.Rtab --table samplesheet.csv --column resistant --lmm --tree tree.rooted.nwk --clusters clusters.tsv --lineage-column lineage '
614 """)
615 gwas_parser = subbcommands.add_parser('gwas', description=description, formatter_class=argparse.RawTextHelpFormatter)
616 # Required options
617 gwas_req_parser = gwas_parser.add_argument_group("required arguments")
618 gwas_req_parser.add_argument('--variants', required=True, help='Rtab file of variants.')
619 gwas_req_parser.add_argument('--table', required=True, help='TSV or CSV table of phenotypes (variables).')
620 gwas_req_parser.add_argument('--column', required=True, help='Column name of trait (variable) in table.')
621 # Lineage effects options
622 gwas_lin_parser = gwas_parser.add_argument_group("optional arguments (bugwas)")
623 gwas_lin_parser.add_argument('--lineage-column', help='Column name of lineages in table. Enables bugwas lineage effects.')
624 # LMM options
625 gwas_lmm_parser = gwas_parser.add_argument_group("optional arguments (lmm)")
626 gwas_lmm_parser.add_argument('--no-midpoint', help='Disable midpoint rooting of the tree for the kinship matrix.', dest="midpoint", action="store_false")
627 gwas_lmm_parser.add_argument('--tree', help='Newick phylogenetic tree. Not required if pyseer args includes --no-distances.')
628 # Data Options
629 gwas_data_parser = gwas_parser.add_argument_group("optional arguments (data)")
630 gwas_data_parser.add_argument('--continuous', help='Treat the column trait as a continuous variable.', action="store_true")
631 gwas_data_parser.add_argument('--exclude-missing', help='Exclude samples missing phenotype data.', action="store_true")
633 gwas_opt_parser = gwas_parser.add_argument_group("optional arguments")
634 gwas_opt_parser.add_argument('--clusters', help='TSV of clusters from summarize.')
635 gwas_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
636 gwas_opt_parser.add_argument('--prefix', help='Prefix for output files.')
637 gwas_opt_parser.add_argument('--threads', help='CPU threads for pyseer. (default: 1)', default=1)
639 # -------------------------------------------------------------------------
640 # Plot
642 description = textwrap.dedent(
643 f"""\
644 {heatmap_description}
646 Takes as input a table of variants and/or a newick tree. The table can be either
647 an Rtab file, or the locus effects TSV output from the gwas subcommand.
648 If both a tree and a table are provided, the tree will determine the sample order
649 and arrangement. If just a table is provided, sample order will follow the
650 order of the sample columns. A TXT of focal sample IDs can also be supplied
651 with one sample ID per line. Outputs a plot in SVG and PNG format.
653 Examples:
654 > pangwas heatmap --tree tree.rooted.nwk
655 > pangwas heatmap --rtab combine.Rtab
656 > pangwas heatmap --gwas resistant.locus_effects.significant.tsv
657 > pangwas heatmap --tree tree.rooted.nwk --rtab combine.Rtab --focal focal.txt
658 > pangwas heatmap --tree tree.rooted.nwk --gwas resistant.locus_effects.significant.tsv
659 > pangwas heatmap --tree tree.rooted.nwk --tree-width 500 --png-scale 2.0
660 """
661 )
663 heatmap_parser = subbcommands.add_parser('heatmap', description=description, formatter_class=argparse.RawTextHelpFormatter)
665 heatmap_req_parser = heatmap_parser.add_argument_group("optional variant arguments (mutually-exclusive)")
666 heatmap_variants_input = heatmap_req_parser.add_mutually_exclusive_group(required=False)
667 heatmap_variants_input.add_argument('--gwas', help='TSV table of variants from gwas subcommand.')
668 heatmap_variants_input.add_argument('--rtab', help='Rtab table of variants.')
670 heatmap_tree_parser = heatmap_parser.add_argument_group("optional tree arguments")
671 heatmap_tree_parser.add_argument('--tree', help='Tree file.')
672 heatmap_tree_parser.add_argument('--tree-format', help='Tree format. (default: newick)', type=str, default="newick")
673 heatmap_tree_parser.add_argument('--tree-width', help='Width of the tree in pixels. (default: 200)', type=int, default=200)
674 heatmap_tree_parser.add_argument('--root-branch', help='Root branch length in pixels. (default: 10)', type=int, default=10)
676 heatmap_opt_parser = heatmap_parser.add_argument_group("optional arguments")
677 heatmap_opt_parser.add_argument('--focal', help='TXT file of focal samples.')
678 heatmap_opt_parser.add_argument('--font-size', help='Font size of the tree labels. (default: 16)', type=int, default=16)
679 heatmap_opt_parser.add_argument('--font-family', help='Font family of the tree labels. (default: Roboto)', type=str, default="Roboto")
680 heatmap_opt_parser.add_argument('--heatmap-scale', help='Scaling factor of heatmap boxes relative to the text. (default: 1.5)', type=float, default=1.5)
681 heatmap_opt_parser.add_argument('--margin', help='Margin sizes in pixels. (default: 20)', type=int, default=20)
682 heatmap_opt_parser.add_argument('--min-score', help='Filter GWAS variants by a minimum score (range: -1.0 to 1.0).', type=float)
683 heatmap_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
684 heatmap_opt_parser.add_argument('--png-scale', help='Scaling factor of the PNG file. (default: 2.0)', type=float, default=2.0)
685 heatmap_opt_parser.add_argument('--prefix', help='Prefix for output files.')
686 heatmap_opt_parser.add_argument('--tip-pad', help='Tip padding. (default: 10)', type=int, default=10)
688 # -------------------------------------------------------------------------
689 # Manhattan
691 description = textwrap.dedent(
692 f"""\
693 {manhattan_description}
695 Takes as input a table of locus effects from the subcommand and a bed
696 file such as the one producted by the align subcommand. Outputs a
697 manhattan plot in SVG and PNG format.
699 Examples:
700 > pangwas manhattan --gwas locus_effects.tsv --bed pangenome.bed
701 > pangwas manhattan --gwas locus_effects.tsv --bed pangenome.bed --syntenies chromosome --clusters pbpX --variant-types snp presence_absence
702 """
703 )
705 manhattan_parser = subbcommands.add_parser('manhattan', description=description, formatter_class=argparse.RawTextHelpFormatter)
707 manhattan_req_parser = manhattan_parser.add_argument_group("required arguments")
708 manhattan_req_parser.add_argument('--bed', required=True, help='BED file with region coordinates and names.')
709 manhattan_req_parser.add_argument('--gwas', required=True, help='TSV table of variants from gwas subcommand.')
711 manhattan_opt_parser = manhattan_parser.add_argument_group("optional filter arguments")
712 manhattan_opt_parser.add_argument('--clusters', help='Only plot these clusters. (default: all)', nargs='+', default=["all"])
713 manhattan_opt_parser.add_argument('--syntenies', help='Only plot these synteny blocks. (default: all)', nargs='+', default=["all"])
714 manhattan_opt_parser.add_argument('--variant-types', help='Only plot these variant types. (default: all)', nargs='+', default=["all"])
716 manhattan_opt_parser = manhattan_parser.add_argument_group("optional arguments")
717 manhattan_opt_parser.add_argument('--font-size', help='Font size of the tree labels. (default: 16)', type=int, default=16)
718 manhattan_opt_parser.add_argument('--font-family', help='Font family of the tree labels. (default: Roboto)', type=str, default="Roboto")
719 manhattan_opt_parser.add_argument('--height', help='Plot height in pixels.', type=int, default=500)
720 manhattan_opt_parser.add_argument('--margin', help='Margin sizes in pixels. (default: 20)', type=int, default=20)
721 manhattan_opt_parser.add_argument('--max-blocks', help='Maximum number of blocks to draw before switching to pangenome coordinates. (default: 20)', type=int, default=20)
722 manhattan_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".")
723 manhattan_opt_parser.add_argument('--png-scale', help='Scaling factor of the PNG file. (default: 2.0)', type=float, default=2.0)
724 manhattan_opt_parser.add_argument('--prefix', help='Prefix for output files.')
725 manhattan_opt_parser.add_argument('--prop-x-axis', help='Make x-axis proporional to genomic length.', action="store_true")
726 manhattan_opt_parser.add_argument('--width', help='Plot width in pixels.', type=int, default=1000)
727 manhattan_opt_parser.add_argument('--ymax', help='A -log10 value to use as the y axis max, to synchronize multiple plots.', type=float)
729 # -------------------------------------------------------------------------
730 # Finalize
732 defined, undefined = parser.parse_known_args()
733 if len(undefined) > 0:
734 defined.args = " ".join(undefined)
735 logging.warning(f"Using undefined args: {defined.args}")
737 sys.argv = sys_argv_original
738 return defined
741def check_output_dir(output_dir):
742 if output_dir != None and output_dir != "." and output_dir != "" and not os.path.exists(output_dir):
743 logging.info(f"Creating output directory: {output_dir}")
744 os.makedirs(output_dir)
747def reverse_complement(sequence):
748 lookup = {"A": "T", "C": "G", "G": "C", "T": "A"}
749 split = []
750 for nuc in sequence[::-1]:
751 nuc = nuc.upper()
752 compl = lookup[nuc] if nuc in lookup else nuc
753 split.append(compl)
754 return "".join(split)
757def extract_cli_param(args: str, target: str):
758 """
759 Extract a value flag from a string of CLI arguments.
761 :param args: String of CLI arguments (ex. 'iqtree -T 3 --seed 123').
762 :param args: String of target param (ex. '-T', '--seed')
763 """
764 if not target.endswith(" "):
765 target += " "
766 val = None
767 if target in args:
768 start = args.index(target) + len(target)
769 try:
770 end = args[start:].index(' ')
771 val = args[start:][:end]
772 except ValueError:
773 val = args[start:]
774 return val
777def run_cmd(cmd: str, output=None, err=None, quiet=False, display_cmd=True):
778 """
779 Run a subprocess command.
780 """
781 if display_cmd:
782 logging.info(f"{cmd}")
783 cmd_run = [str(c) for c in cmd.replace("\n", " ").split(" ") if c != ""]
784 try:
785 result = subprocess.run(cmd_run, check=True, capture_output=True, text=True)
786 except subprocess.CalledProcessError as error:
787 logging.error(f"stdout:\n{error.stdout}")
788 logging.error(f"stderr:\n{error.stderr}")
789 raise Exception(f"Error in command: {cmd}\n" + error.stderr)
790 else:
791 if not quiet:
792 logging.info(result.stdout)
793 if output:
794 with open(output, 'w') as outfile:
795 outfile.write(result.stdout)
796 if err:
797 with open(err, 'w') as errfile:
798 errfile.write(result.stderr)
799 return (result.stdout, result.stderr)
801def annotate(
802 fasta: str,
803 db: str,
804 outdir: str = ".",
805 prefix: str = None,
806 sample: str = None,
807 threads: int = 1,
808 tmp: str = None,
809 args: str = None
810 ):
811 """
812 Annotate genomic assemblies with bakta.
814 Takes as input a FASTA file of genomic assemblies. Outputs a GFF file
815 of annotations, among many other formats from bakta.
817 Any additional arguments in `args` will be passed to `bakta`.
819 >>> annotate(fasta="sample1.fasta", db="database/bakta")
820 >>> annotate(fasta="sample2.fasta", db="database/bakta", threads=2, args="--genus Streptococcus")
822 :param fasta: File path to the input FASTA file.
823 :param db: Directory path of the bakta database.
824 :param outdir: Output directory.
825 :param prefix: Prefix for output files.
826 :param sample: Sample identifier for output files.
827 :param threads: CPU threads for bakta.
828 :param tmp: Temporary directory.
829 :param args: Str of additional arguments to pass to bakta
831 :return: Path to the output GFF annotations.
832 """
833 from bakta.main import main as bakta
835 output_dir, tmp_dir = outdir, tmp
837 # Argument Checking
838 args = args if args != None else ""
839 sample = sample if sample else os.path.basename(fasta).split(".")[0]
840 logging.info(f"Using sample identifier: {sample}")
841 prefix = prefix if prefix != None else os.path.basename(fasta).split(".")[0]
842 logging.info(f"Using prefix: {prefix}")
843 if tmp != None:
844 args += f" --tmp-dir {tmp}"
846 # bakta needs the temporary directory to already exist
847 # bakta explicity wants the output directory to NOT exist
848 # so that is why we don't create it
849 check_output_dir(tmp_dir)
851 # Set up cmd to execute, we're not going to use the run_cmd, because
852 # bakta is a python package, and we call it directly! This also has the
853 # benefit of better/immediate logging.
854 cmd = f"bakta --force --prefix {prefix} --db {db} --threads {threads} --output {output_dir} {args} {fasta}"
855 logging.info(f"{cmd}")
857 # Split cmd string and handle problematic whitespace
858 cmd_run = [str(c) for c in cmd.replace("\n", " ").split(" ") if c != ""]
859 # Manually set sys.argv for bakta
860 sys_argv_original = sys.argv
861 sys.argv = cmd_run
862 bakta()
864 # Fix bad quotations
865 for ext in ["embl", "faa", "ffn", "gbff", "gff3", "tsv"]:
866 file_path = os.path.join(output_dir, f"{prefix}.{ext}")
867 logging.info(f"Fixing bad quotations in file: {file_path}")
868 with open(file_path) as infile:
869 content = infile.read()
870 with open(file_path, 'w') as outfile:
871 outfile.write(content.replace('‘', '').replace('’', ''))
873 # Restore original sys.argv
874 sys.argv = sys_argv_original
876 gff_path = os.path.join(output_dir, f"{sample}.gff3")
877 return gff_path
880def extract(
881 gff: str,
882 sample: str = None,
883 fasta: str = None,
884 outdir: str = ".",
885 prefix: str = None,
886 min_len: int = 20,
887 max_len: int = 100000,
888 regex: str = None,
889 args: str = None
890 ) -> OrderedDict:
891 """
892 Extract sequence records from GFF annotations.
894 Takes as input a GFF annotations file. If sequences are not included, a FASTA
895 of genomic contigs must also be provided. Both annotated and unannotated regions
896 will be extracted. Outputs a TSV table of extracted sequence regions.
898 >>> extract(gff='sample1.gff3')
899 >>> extract(gff='sample2.gff3', fasta="sample2.fasta", min_len=10)
901 :param gff: File path to the input GFF file.
902 :param sample: A sample identifier that will be used in output files.
903 :param fasta: File path to the input FASTA file (if GFF doesn't have sequences).
904 :param outdir: Output directory.
905 :param prefix: Output file path prefix.
906 :param min_len: The minimum length of sequences/annotations to extract.
907 :param regex: Only extract lines that match this regex.
908 :param args: Str of additional arguments [not implemented]
910 :return: Sequence records and annotations as an OrderedDict.
911 """
913 from Bio import SeqIO
914 from io import StringIO
915 import re
917 gff_path, fasta_path, output_dir = gff, fasta, outdir
919 # Argument checking
920 args = args if args != None else ""
921 sample = sample if sample != None else os.path.basename(gff).split(".")[0]
922 logging.info(f"Using sample identifier: {sample}")
923 prefix = f"{prefix}" if prefix != None else os.path.basename(gff).split(".")[0]
924 logging.info(f"Using prefix: {prefix}")
925 if "/" in prefix:
926 msg = "Prefix cannot contain slashes (/)"
927 logging.error(msg)
928 raise Exception(msg)
929 min_len = int(min_len)
930 max_len = int(max_len) if max_len != None else None
932 # Check output directory
933 check_output_dir(output_dir)
935 # A visual separator for creating locus names for unannotated regions
936 # ex. 'SAMPL01_00035___SAMPL01_00040' indicating the region
937 # falls between two annotated regions: SAMPL01_00035 and SAMPL01_00040
938 delim = "__"
940 # -------------------------------------------------------------------------
941 # Extract contig sequences
943 logging.info(f"Reading GFF: {gff_path}")
945 contigs = OrderedDict()
946 with open(gff) as infile:
947 gff = infile.read()
948 if "##FASTA" in gff:
949 logging.info(f"Extracting sequences from GFF.")
950 fasta_i = gff.index("##FASTA")
951 fasta_string = gff[fasta_i + len("##FASTA"):].strip()
952 fasta_io = StringIO(fasta_string)
953 records = SeqIO.parse(fasta_io, "fasta")
954 elif fasta_path == None:
955 msg = f"A FASTA file must be provided if no sequences are in the GFF: {gff_path}"
956 logging.error(msg)
957 raise Exception(msg)
959 if fasta_path != None:
960 logging.info(f"Extracting sequences from FASTA: {fasta_path}")
961 if "##FASTA" in gff:
962 logging.warning(f"Sequences found in GFF, fasta will be ignored: {fasta}")
963 else:
964 records = SeqIO.parse(fasta_path, "fasta")
965 fasta_i = None
967 sequences_seen = set()
968 for record in records:
969 if record.id in sequences_seen:
970 msg = f"Duplicate sequence ID found: {record.id}"
971 logging.error(msg)
972 raise Exception(msg)
973 sequences_seen.add(record.id)
974 contigs[record.id] = record.seq
976 # -------------------------------------------------------------------------
977 # Extract annotations
979 logging.info(f"Extracting annotations from gff: {gff_path}")
981 # Parsing GFF is not yet part of biopython, so this is done manually:
982 # https://biopython.org/wiki/GFF_Parsing
984 annotations = OrderedDict()
985 contig = None
986 sequence = ""
988 # Keep track of loci, in case we need to flag duplicates
989 locus_counts = {}
990 gene_counts = {}
991 locus_to_contig = {}
993 comment_contigs = set()
995 for i,line in enumerate(gff.split("\n")):
996 # If we hit the fasta, we've seen all annotations
997 if line.startswith("##FASTA"): break
998 if line.startswith("##sequence-region"):
999 _comment, contig, start, end = line.split(" ")
1000 start, end = int(start), int(end)
1001 if contig not in annotations:
1002 annotations[contig] = {"start": start, "end": end, "length": 1 + end - start, "loci": OrderedDict()}
1003 comment_contigs.add(contig)
1004 continue
1005 # Skip over all other comments
1006 if line.startswith("#"): continue
1007 # Skip over empty lines
1008 if line == "": continue
1010 # Parse standardized annotation fields
1011 line = line.replace("\"", "")
1012 line_split = line.split("\t")
1013 if len(line_split) < 9:
1014 msg = f"GFF record does not contain 9 fields: {line}"
1015 logging.error(msg)
1016 raise Exception(msg)
1017 contig, _source, feature, start, end, _score, strand, _frame = [line_split[i] for i in range(0,8)]
1018 start, end = int(start), int(end)
1019 if feature == "region" or feature == "databank_entry":
1020 if contig in comment_contigs:
1021 continue
1022 elif contig in annotations:
1023 msg = f"Duplicate contig ID found: {contig}"
1024 logging.error(msg)
1025 raise Exception(msg)
1026 annotations[contig] = {"start": start, "end": end, "length": 1 + end - start, "loci": OrderedDict()}
1027 continue
1028 elif regex != None and not re.match(regex, line, re.IGNORECASE):
1029 continue
1031 attributes = line_split[8]
1032 attributes_dict = {a.split("=")[0].replace(" ", ""):a.split("=")[1] for a in attributes.split(";") if "=" in a}
1033 locus = f"{attributes_dict['ID']}"
1034 if "gene" in attributes_dict:
1035 gene = attributes_dict["gene"]
1036 if gene not in gene_counts:
1037 gene_counts[gene] = {"count": 0, "current": 1}
1038 gene_counts[gene]["count"] += 1
1040 # Add sample prefix to help duplicate IDs later
1041 if not locus.startswith(sample):
1042 locus = f"{sample}_{locus}"
1044 # Check for duplicate loci, how does this happen in NCBI annotations?
1045 if locus not in locus_counts:
1046 locus_counts[locus] = 1
1047 else:
1048 locus_counts[locus] += 1
1049 dup_i = locus_counts[locus]
1050 locus = f"{locus}.{dup_i}"
1051 logging.debug(f"Duplicate locus ID found, flagging as: {locus}")
1053 if contig not in contigs:
1054 msg = f"Contig {contig} in {sample} annotations is not present in the sequences."
1055 logging.error(msg)
1056 raise Exception(msg)
1057 sequence = contigs[contig][start - 1:end]
1058 if strand == "-":
1059 sequence = reverse_complement(sequence)
1061 data = OrderedDict({
1062 "sample" : sample,
1063 "contig" : contig,
1064 "locus" : locus,
1065 "feature" : feature,
1066 "start" : start,
1067 "end" : end,
1068 "length" : 1 + end - start,
1069 "strand" : strand,
1070 "upstream" : "",
1071 "downstream" : "",
1072 "attributes" : attributes,
1073 "sequence_id" : f"{locus}",
1074 "sequence" : sequence,
1075 })
1076 logging.debug(f"\tcontig={contig}, locus={locus}")
1077 annotations[contig]["loci"][locus] = data
1078 locus_to_contig[locus] = contig
1080 # Check for duplicates, rename the first duplicate loci
1081 logging.info(f"Checking for duplicate locus IDs.")
1082 for locus,count in locus_counts.items():
1083 if count <= 1: continue
1084 contig = locus_to_contig[locus]
1085 data = annotations[contig]["loci"][locus]
1086 new_locus = f"{locus}.1"
1087 logging.debug(f"Duplicate locus ID found, flagging as: {new_locus}")
1088 data["locus"], data["sequence_id"] = new_locus, new_locus
1089 annotations[contig]["loci"] = OrderedDict({
1090 new_locus if k == locus else k:v
1091 for k,v in annotations[contig]["loci"].items()
1092 })
1094 # -------------------------------------------------------------------------
1095 # Extract unannotated regions
1097 if regex != None:
1098 logging.info("Skipping extraction of unannotate regions due to regex.")
1099 else:
1100 logging.info(f"Extracting unannotated regions.")
1101 for contig,contig_data in annotations.items():
1102 if contig not in contigs:
1103 msg = f"Contig {contig} in {sample} annotations is not present in the sequences."
1104 logging.error(msg)
1105 raise Exception(msg)
1106 contig_sequence = contigs[contig]
1107 num_annotations = len([k for k,v in contig_data["loci"].items() if v["feature"] != "region"])
1108 contig_start = contig_data["start"]
1109 contig_end = contig_data["end"]
1111 # Contig minimum length check
1112 contig_len = (1 + contig_end - contig_start)
1113 if contig_len < min_len: continue
1115 logging.debug(f"\tcontig={contig}")
1117 # If there were no annotations on the contig, extract entire contig as sequence
1118 if num_annotations == 0:
1119 logging.debug(f"\t\tlocus={contig}")
1120 locus = f"{contig}"
1121 if not locus.startswith(sample):
1122 locus = f"{sample}_{locus}"
1123 l_data = {
1124 "sample" : sample,
1125 "contig" : contig,
1126 "locus" : locus,
1127 "feature" : "unannotated",
1128 "start" : contig_start,
1129 "end" : contig_end,
1130 "length" : 1 + contig_end - contig_start,
1131 "strand" : "+",
1132 "upstream" : f"{contig}_TERMINAL",
1133 "downstream" : f"{contig}_TERMINAL",
1134 "attributes" : "",
1135 "sequence_id" : locus,
1136 "sequence" : contig_sequence,
1137 }
1138 annotations[contig]["loci"][contig] = l_data
1139 logging.debug(f"\t\t\tfull contig unannotated, locus={locus}")
1140 continue
1142 contig_annotations = list(contig_data["loci"].keys())
1144 for i,locus in enumerate(contig_annotations):
1145 logging.debug(f"\t\tlocus={locus}")
1146 l_data = contig_data["loci"][locus]
1147 l_start, l_end = l_data["start"], l_data["end"]
1149 # Make a template for an unannotated region
1150 l_data = {
1151 "sample" : sample,
1152 "contig" : contig,
1153 "locus" : None,
1154 "feature" : "unannotated",
1155 "start" : None,
1156 "end" : None,
1157 "length" : 0,
1158 "strand" : "+",
1159 "upstream" : "",
1160 "downstream" : "",
1161 "attributes" : "",
1162 "sequence_id" : None,
1163 "sequence" : None,
1164 }
1165 # Find the inter-genic regions and upstream/downstream loci
1166 start, end, sequence, upstream, downstream = None, None, None, None, None
1168 # Case 1. Unannotated at the start of the contig
1169 if i == 0 and l_start != contig_start:
1170 start = 1
1171 end = l_start - 1
1172 length = 1 + end - start
1173 if length >= min_len and (max_len == None or length <= max_len):
1174 upstream = f"{contig}_TERMINAL"
1175 downstream = locus
1176 # Base strand on the downstream loci
1177 strand = annotations[contig]["loci"][downstream]["strand"]
1178 sequence_id = f"{upstream}{delim}{downstream}"
1179 sequence = contig_sequence[start - 1:end]
1180 # If the downstream is reverse, treat unannotated as reversed
1181 if strand == "-":
1182 sequence = reverse_complement(sequence)
1183 l_data_start = copy.deepcopy(l_data)
1184 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]):
1185 l_data_start[k] = v
1186 l_data_start["upstream"], l_data_start["downstream"] = upstream, downstream
1187 logging.debug(f"\t\t\tunannotated at start, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}")
1188 annotations[contig]["loci"][sequence_id] = l_data_start
1190 # Update upstream for annotation
1191 annotations[contig]["loci"][locus]["upstream"] = sequence_id
1193 # Case 2. Unannotated at the end of the contig
1194 if i == (num_annotations - 1) and l_end != contig_end:
1195 start = l_end + 1
1196 end = contig_end
1197 length = 1 + end - start
1198 if length >= min_len and (max_len == None or length <= max_len):
1199 upstream = locus
1200 # base strand on the upstream loci
1201 strand = annotations[contig]["loci"][upstream]["strand"]
1202 downstream = f"{contig}_TERMINAL"
1203 sequence_id = f"{upstream}{delim}{downstream}"
1204 sequence = contig_sequence[start - 1:end]
1205 # If the upstream is reversed, treat unannotated as reversed
1206 if strand == "-":
1207 sequence = reverse_complement(sequence)
1208 l_data_end = copy.deepcopy(l_data)
1209 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]):
1210 l_data_end[k] = v
1211 l_data_end["upstream"], l_data_end["downstream"] = upstream, downstream
1212 logging.debug(f"\t\t\tunannotated at end, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}")
1213 annotations[contig]["loci"][sequence_id] = l_data_end
1215 # Update downstream for annotation
1216 annotations[contig]["loci"][locus]["downstream"] = sequence_id
1218 # Case 3. Unannotated in between annotations
1219 if num_annotations > 1 and i != (num_annotations - 1):
1221 upstream = locus
1222 downstream = contig_annotations[i+1]
1223 start = l_end + 1
1224 end = contig_data["loci"][downstream]["start"] - 1
1225 length = 1 + end - start
1226 # Set upstream downstream based on order in GFF
1227 upstream_strand = annotations[contig]["loci"][upstream]["strand"]
1228 downstream_strand = annotations[contig]["loci"][downstream]["strand"]
1230 # Check that the region is long enough
1231 if length >= min_len and (max_len == None or length <= max_len):
1232 sequence_id = f"{upstream}{delim}{downstream}"
1233 sequence = contig_sequence[start - 1:end]
1234 # Identify strand and orientation
1235 if upstream_strand == "-" and downstream_strand == "-":
1236 strand = "-"
1237 sequence = reverse_complement(sequence)
1238 else:
1239 strand = "+"
1240 l_data_middle = copy.deepcopy(l_data)
1241 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]):
1242 l_data_middle[k] = v
1243 l_data_middle["upstream"], l_data_middle["downstream"] = upstream, downstream
1244 logging.debug(f"\t\t\tunannotated in middle, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}")
1245 annotations[contig]["loci"][sequence_id] = l_data_middle
1247 # -------------------------------------------------------------------------
1248 # Order and Filter
1250 logging.info(f"Ordering records by contig and coordinate.")
1251 contig_order = list(annotations.keys())
1252 for contig in contig_order:
1253 loci = annotations[contig]["loci"]
1254 # sort by start and end position
1255 loci_sorted = OrderedDict(sorted(loci.items(), key=lambda item: [item[1]["start"], item[1]["end"]]) )
1256 annotations[contig]["loci"] = OrderedDict()
1257 for locus,locus_data in loci_sorted.items():
1258 # final checks, minimum length and not entirely ambiguous characters
1259 non_ambig = list(set([n for n in locus_data["sequence"] if n in NUCLEOTIDES]))
1260 length = locus_data["length"]
1261 if (length >= min_len) and (len(non_ambig) > 0):
1262 if max_len == None or length <= max_len:
1263 annotations[contig]["loci"][locus] = locus_data
1266 # -------------------------------------------------------------------------
1267 # Upstream/downtream loci
1269 for contig in annotations:
1270 contig_data = annotations[contig]
1271 c_start, c_end = contig_data["start"], contig_data["end"]
1272 loci = list(contig_data["loci"].keys())
1273 for i,locus in enumerate(contig_data["loci"]):
1274 upstream, downstream = "", ""
1275 locus_data = contig_data["loci"][locus]
1276 l_start, l_end = locus_data["start"], locus_data["end"]
1277 # Annotation at very start
1278 if i > 0:
1279 upstream = loci[i-1]
1280 elif l_start == c_start:
1281 upstream = f"{contig}_TERMINAL"
1282 # Annotation at very end
1283 if i < (len(loci) - 1):
1284 downstream = loci[i+1]
1285 elif l_end == c_end:
1286 downstream = f"{contig}_TERMINAL"
1287 annotations[contig]["loci"][locus]["upstream"] = upstream
1288 annotations[contig]["loci"][locus]["downstream"] = downstream
1290 # -------------------------------------------------------------------------
1291 # Write Output
1293 tsv_path = os.path.join(output_dir, prefix + ".tsv")
1294 logging.info(f"Writing output tsv: {tsv_path}")
1295 with open(tsv_path, 'w') as outfile:
1296 header = None
1297 for contig,contig_data in annotations.items():
1298 for locus,locus_data in contig_data["loci"].items():
1299 if locus_data["feature"] == "region": continue
1300 if header == None:
1301 header = list(locus_data.keys())
1302 outfile.write("\t".join(header) + "\n")
1303 row = [str(v) for v in locus_data.values()]
1304 # Restore commas from there %2C encoding
1305 line = "\t".join(row).replace("%2C", ",")
1306 outfile.write(line + "\n")
1308 return tsv_path
1311def collect(
1312 tsv: list = None,
1313 tsv_paths: str = None,
1314 outdir : str = ".",
1315 prefix : str = None,
1316 args: str = None,
1317 ) -> str:
1318 """
1319 Collect sequences from multiple samples into one file.
1321 Takes as input multiple TSV files from the extract subcommand, which can
1322 be supplied as either space separate paths, or a text file containing paths.
1323 Duplicate sequence IDs will be identified and given the suffix '.#'.
1324 Outputs concatenated FASTA and TSV files.
1326 >>> collect(tsv=["sample1.tsv", "sample2.tsv"])
1327 >>> collect(tsv_paths='paths.txt')
1329 :param tsv: List of TSV file paths from the output of extract.
1330 :param tsv_paths: TXT file with each line containing a TSV file path.
1331 :param output: Path to the output directory.
1332 :param prefix: Output file prefix.
1333 :param args: Additional arguments [not implemented]
1335 :return: Tuple of output FASTA and TSV paths.
1336 """
1338 from Bio import SeqIO
1340 tsv_paths, tsv_txt_path, output_dir = tsv, tsv_paths, outdir
1341 prefix = f"{prefix}." if prefix != None else ""
1343 # Check output directory
1344 check_output_dir(output_dir)
1346 # TSV file paths
1347 tsv_file_paths = []
1348 if tsv_txt_path != None:
1349 logging.info(f"Reading tsv paths from file: {tsv_txt_path}")
1350 with open(tsv_txt_path) as infile:
1351 for line in infile:
1352 file_path = [l.strip() for l in line.split("\t")][0]
1353 tsv_file_paths.append(file_path.strip())
1354 if tsv_paths != None:
1355 for file_path in tsv_paths:
1356 tsv_file_paths.append(file_path)
1358 logging.info(f"Checking for duplicate samples and sequence IDs.")
1359 all_samples = []
1360 sequence_id_counts = {}
1361 for file_path in tqdm(tsv_file_paths):
1362 with open(file_path) as infile:
1363 header = [l.strip() for l in infile.readline().split("\t")]
1364 for i,line in enumerate(infile):
1365 row = [l.strip() for l in line.split("\t")]
1366 data = {k:v for k,v in zip(header, row)}
1368 # Check for duplicate samples
1369 if i == 0:
1370 sample = data["sample"]
1371 if sample in all_samples:
1372 msg = f"Duplicate sample ID found: {sample}"
1373 logging.error(msg)
1374 raise Exception(msg)
1375 all_samples.append(sample)
1377 # Check for duplicate IDs
1378 sequence_id = data["sequence_id"]
1379 if sequence_id not in sequence_id_counts:
1380 sequence_id_counts[sequence_id] = {"count": 0, "current": 1}
1381 sequence_id_counts[sequence_id]["count"] += 1
1383 output_fasta = os.path.join(outdir, f"{prefix}sequences.fasta")
1384 logging.info(f"Writing output fasta file to: {output_fasta}")
1385 output_tsv = os.path.join(outdir, f"{prefix}regions.tsv")
1386 logging.info(f"Writing output tsv file to: {output_tsv}")
1388 fasta_outfile = open(output_fasta, 'w')
1389 tsv_outfile = open(output_tsv, "w")
1390 header = None
1392 for file_path in tqdm(tsv_file_paths):
1393 with open(file_path) as infile:
1394 header_line = infile.readline()
1395 if header == None:
1396 header = [l.strip() for l in header_line.split("\t")]
1397 tsv_outfile.write(header_line)
1398 for line in infile:
1399 row = [l.strip() for l in line.split("\t")]
1400 data = {k:v for k,v in zip(header, row)}
1402 # Handle duplicate sequence IDs with a .# suffix
1403 sequence_id = data["sequence_id"]
1404 count = sequence_id_counts[sequence_id]["count"]
1405 if count > 1:
1406 dup_i = sequence_id_counts[sequence_id]["current"]
1407 sequence_id_counts[sequence_id]["current"] += 1
1408 data["sequence_id"] = sequence_id = f"{sequence_id}.{dup_i}"
1409 logging.debug(f"Duplicate sequence ID found in TSV, flagging as: {sequence_id}")
1411 line = "\t".join([str(data[col]) for col in header])
1412 tsv_outfile.write(line + "\n")
1414 sequence = data["sequence"]
1415 line = f">{sequence_id}\n{sequence}"
1416 fasta_outfile.write(line + "\n")
1418 fasta_outfile.close()
1419 tsv_outfile.close()
1421 return (output_fasta, output_tsv)
1424def cluster(
1425 fasta: str,
1426 outdir: str = ".",
1427 prefix: str = None,
1428 threads: int = 1,
1429 memory: str = "1G",
1430 tmp: str = "tmp",
1431 clean: bool = True,
1432 args: str = CLUSTER_ARGS,
1433 ):
1434 """
1435 Cluster nucleotide sequences with mmseqs.
1437 Takes as input a FASTA file of sequences for clustering from collect.
1438 Calls MMSeqs2 to cluster sequences and identify a representative sequence.
1439 Outputs a TSV table of sequence clusters and a FASTA of representative sequences.
1441 Note: The default kmer size (15) requires at least 9G of memory. To use less
1442 memory, please set the kmer size to '-k 13'.
1444 >>> cluster(fasta='sequences.fasta')
1445 >>> cluster(fasta='sequences.fasta', threads=2, memory='2G', args='-k 13 --min-seq-id 0.90 -c 0.90')
1446 >>> cluster(fasta='sequences.fasta', threads=4, args='-k 13 --min-seq-id 0.90 -c 0.90')
1448 :param fasta: Path fo FASTA sequences.
1449 :param outdir: Output directory.
1450 :param prefix: Prefix for output files.
1451 :param threads: CPU threads for MMSeqs2.
1452 :param memory: Memory for MMSeqs2.
1453 :param tmp: Path to a temporary directory.
1454 :param clean: True if intermediate files should be cleaned up.
1455 :param args: Additional parameters for MMSeqs2 cluster command.
1456 """
1458 from Bio import SeqIO
1460 fasta_path, tmp_dir, output_dir = fasta, tmp, outdir
1461 prefix = f"{prefix}." if prefix != None else ""
1463 # Wrangle the output directory
1464 check_output_dir(output_dir)
1465 check_output_dir(tmp_dir)
1467 args = args if args != None else ""
1469 # fix memory formatting (ex. '6 GB' -> '6G')
1470 memory = memory.replace(" ", "").replace("B", "")
1472 # 1. Cluster Sequences
1473 seq_db = os.path.join(output_dir, f"{prefix}seqDB")
1474 clust_db = os.path.join(output_dir, f"{prefix}clustDB")
1475 tsv_path = os.path.join(output_dir, f"{prefix}clusters.tsv")
1476 run_cmd(f"mmseqs createdb {fasta_path} {seq_db}")
1477 try:
1478 run_cmd(f"mmseqs cluster {seq_db} {clust_db} {tmp_dir} --threads {threads} --split-memory-limit {memory} {args}")
1479 except Exception as e:
1480 if "Segmentation fault" in f"{e}":
1481 logging.error(f"Segmentation fault. Try changing your threads and/or memory. Memory consumption is {threads} x {memory}.")
1482 raise Exception(e)
1483 run_cmd(f"mmseqs createtsv {seq_db} {seq_db} {clust_db} {tsv_path} --threads {threads} --full-header")
1485 # Sort and remove quotations in clusters
1486 with open(tsv_path) as infile:
1487 lines = sorted([l.strip().replace("\"", "") for l in infile.readlines()])
1488 with open(tsv_path, 'w') as outfile:
1489 outfile.write("\n".join(lines) + "\n")
1491 # 2. Identify representative sequences
1492 rep_db = os.path.join(output_dir, f"{prefix}repDB")
1493 rep_fasta_path = os.path.join(output_dir, f"{prefix}representative.fasta")
1494 run_cmd(f"mmseqs result2repseq {seq_db} {clust_db} {rep_db} --threads {threads}")
1495 run_cmd(f"mmseqs result2flat {seq_db} {seq_db} {rep_db} {rep_fasta_path} --use-fasta-header")
1497 # Sort sequences, and remove trailing whitespace
1498 sequences = {}
1499 for record in SeqIO.parse(rep_fasta_path, "fasta"):
1500 locus = record.id.strip()
1501 seq = record.seq.strip()
1502 sequences[locus] = seq
1503 with open(rep_fasta_path, 'w') as outfile:
1504 for locus in sorted(list(sequences.keys())):
1505 seq = sequences[locus]
1506 outfile.write(f">{locus}\n{seq}\n")
1508 # Sort clusters
1509 with open(tsv_path) as infile:
1510 lines = sorted([line.strip() for line in infile])
1511 with open(tsv_path, 'w') as outfile:
1512 outfile.write("\n".join(lines))
1514 # Cleanup
1515 if clean == True:
1516 for file_name in os.listdir(output_dir):
1517 file_path = os.path.join(output_dir, file_name)
1518 db_prefix = os.path.join(output_dir, prefix)
1519 if (
1520 file_path.startswith(f"{db_prefix}seqDB") or
1521 file_path.startswith(f"{db_prefix}clustDB") or
1522 file_path.startswith(f"{db_prefix}repDB")
1523 ):
1524 logging.info(f"Cleaning up file: {file_path}")
1525 os.remove(file_path)
1527 return(tsv_path, rep_fasta_path)
1530def defrag(
1531 clusters: str,
1532 representative: str,
1533 outdir: str = ".",
1534 prefix: str = None,
1535 tmp: str = "tmp",
1536 threads: str = 1,
1537 memory: str = "2G",
1538 clean: str = True,
1539 args: str = DEFRAG_ARGS,
1540 ) -> OrderedDict:
1541 """
1542 Defrag clusters by associating fragments with their parent cluster.
1544 Takes as input the TSV clusters and FASTA representatives from the cluster subcommand.
1545 Outputs a new cluster table and representative sequences fasta.
1547 This is a modification of ppanggolin's refine_clustering function:
1548 https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317
1550 >>> defrag(clusters='clusters.tsv', representative='representative.fasta', prefix="defrag")
1551 >>> defrag(clusters='clusters.tsv', representative='representative.fasta', threads=2, memory='2G', args="-k 13 --min-seq-id 0.90 -c 0.90 --cov-mode 1")
1553 :param clusters: TSV file of clusters.
1554 :param representative: FASTA file of representative sequences from cluster (mmseqs)
1555 :param outdir: Path to the output directory.
1556 :param prefix: Prefix for output files.
1557 :param tmp: Path to a temporary directory.
1558 :param threads: CPU threads for mmseqs.
1559 :param memory: Memory for mmseqs.
1560 :param clean: True if intermediate DB files should be cleaned up.
1561 :param args: Additional arguments for `mmseqs search`.
1564 :return: Ordered Dictionary of new defragmented clusters
1565 """
1567 from Bio import SeqIO
1569 clusters_path, representative_path, output_dir, tmp_dir = clusters, representative, outdir, tmp
1570 prefix = f"{prefix}." if prefix != None else ""
1572 # Check output directory
1573 check_output_dir(output_dir)
1574 check_output_dir(tmp_dir)
1576 args = args if args != None else ""
1578 # fix memory formatting (ex. '6 GB' -> '6G')
1579 memory = memory.replace(" ", "").replace("B", "")
1581 # -------------------------------------------------------------------------
1582 # Align representative sequences against each other
1584 logging.info("Aligning representative sequences.")
1586 seq_db = os.path.join(output_dir, f"{prefix}seqDB")
1587 aln_db = os.path.join(output_dir, f"{prefix}alnDB")
1588 tsv_path = os.path.join(output_dir, f"{prefix}align.tsv")
1589 run_cmd(f"mmseqs createdb {representative_path} {seq_db}")
1590 run_cmd(f"mmseqs search {seq_db} {seq_db} {aln_db} {tmp_dir} --threads {threads} --split-memory-limit {memory} --search-type 3 {args}")
1591 columns="query,target,fident,alnlen,mismatch,gapopen,qstart,qend,qlen,tstart,tend,tlen,evalue,bits"
1592 run_cmd(f"mmseqs convertalis {seq_db} {seq_db} {aln_db} {tsv_path} --search-type 3 --format-output {columns}")
1594 # Sort align rep stats for reproducibility
1595 with open(tsv_path, 'r') as infile:
1596 lines = sorted([l.strip() for l in infile.readlines()])
1597 with open(tsv_path, 'w') as outfile:
1598 header = "\t".join(columns.split(","))
1599 outfile.write(header + "\n")
1600 outfile.write("\n".join(lines) + "\n")
1603 # -------------------------------------------------------------------------
1604 logging.info(f"Reading clusters: {clusters_path}")
1605 loci = OrderedDict()
1606 clusters = OrderedDict()
1607 with open(clusters_path) as infile:
1608 lines = infile.readlines()
1609 for line in tqdm(lines):
1610 cluster, locus = [l.strip() for l in line.split("\t")]
1611 loci[locus] = cluster
1612 if cluster not in clusters:
1613 clusters[cluster] = []
1614 if locus not in clusters[cluster]:
1615 clusters[cluster].append(locus)
1617 # -------------------------------------------------------------------------
1618 logging.info(f"Reading representative: {representative_path}")
1619 representative = OrderedDict()
1620 for record in SeqIO.parse(representative_path, "fasta"):
1621 representative[record.id] = record.seq
1623 # -------------------------------------------------------------------------
1624 # Similarity Graph
1626 # Create a graph of cluster relationships, based on the alignment of their
1627 # representative sequences. The edges between clusters will represent the
1628 # pairwise alignment score (bits).
1630 # This is a modification of ppanggolin's refine_clustering function:
1631 # https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317
1632 # The major difference is that the original function imposes a constraint
1633 # that the fragmented cluster cannot contain more loci than the new parent.
1634 # This function does not use the number of loci, which allows a cluster
1635 # with many small fragments to be reassigned to a longer parent that might
1636 # be represented by only one intact sequence. This function also prefers
1637 # a OrderedDict over a formal graph object, so we can avoid the networkx
1638 # dependency.
1640 logging.info(f"Creating similarity graph from alignment: {tsv_path}")
1642 graph = OrderedDict()
1644 with open(tsv_path) as infile:
1645 header = infile.readline().split()
1646 cluster_i, locus_i, qlen_i, tlen_i, bits_i = [header.index(c) for c in ["query", "target", "qlen", "tlen", "bits"]]
1647 lines = infile.readlines()
1648 for line in tqdm(lines):
1649 row = [r.strip() for r in line.replace('"', '').split()]
1650 if row == []: continue
1651 query, target, qlen, tlen, bits = [row[i] for i in [cluster_i, locus_i, qlen_i, tlen_i, bits_i]]
1652 if query != target:
1653 if query not in graph:
1654 graph[query] = OrderedDict()
1655 if target not in graph:
1656 graph[target] = OrderedDict()
1657 graph[query][target] = {"score": float(bits), "length": int(qlen)}
1658 graph[target][query] = {"score": float(bits), "length": int(tlen)}
1660 # -------------------------------------------------------------------------
1661 logging.info(f"Identifying fragmented loci.")
1663 # Reassign fragmented loci to their new 'parent' cluster which must:
1664 # 1. Be longer (length)
1665 # 2. Have a higher score.
1667 defrag_clusters = OrderedDict({c:{
1668 "loci": clusters[c],
1669 "sequence": representative[c],
1670 "fragments": []}
1671 for c in clusters
1672 })
1674 reassigned = {}
1676 nodes = list(graph.keys())
1678 for node in tqdm(nodes):
1679 # Get the current parent node
1680 candidate_node = None
1681 candidate_score = 0
1682 # Iterate through the candidate targets (clusters it could be aligned against)
1683 logging.debug(f"query: {node}")
1684 for target in graph[node]:
1685 ndata = graph[node][target]
1686 tdata = graph[target][node]
1687 nlen, tlen, tscore = ndata["length"], tdata["length"], tdata["score"]
1688 # Compare lengths and scores
1689 logging.debug(f"\ttarget: {target}, tlen: {tlen}, nlen: {nlen}, tscore: {tscore}, cscore: {candidate_score}")
1690 if tlen > nlen and candidate_score < tscore:
1691 candidate_node = target
1692 candidate_score = tscore
1694 # Check candidate
1695 if candidate_node is not None:
1696 # candidate node might have also been a fragment that got reassigned
1697 while candidate_node not in defrag_clusters and candidate_node in reassigned:
1698 new_candidate_node = reassigned[candidate_node]
1699 logging.debug(f"Following fragments: {node}-->{candidate_node}-->{new_candidate_node}")
1700 candidate_node = new_candidate_node
1701 for locus in clusters[node]:
1702 defrag_clusters[candidate_node]["loci"].append(locus)
1703 defrag_clusters[candidate_node]["fragments"].append(locus)
1704 del defrag_clusters[node]
1705 reassigned[node] = candidate_node
1707 # Sort order for reproducibility
1708 defrag_cluster_order = sorted(list(defrag_clusters.keys()))
1710 defrag_clusters_path = os.path.join(output_dir, f"{prefix}clusters.tsv")
1711 defrag_rep_path = os.path.join(output_dir, f"{prefix}representative.fasta")
1713 with open(defrag_clusters_path, 'w') as clust_file:
1714 with open(defrag_rep_path, 'w') as seq_file:
1715 for cluster in defrag_cluster_order:
1716 info = defrag_clusters[cluster]
1717 # Write sequence
1718 seq_file.write(f">{cluster}\n{info['sequence']}" + "\n")
1719 # Write cluster loci, we sort for reproducibility
1720 for locus in sorted(info["loci"]):
1721 fragment = "F" if locus in info["fragments"] else ""
1722 line = f"{cluster}\t{locus}\t{fragment}"
1723 clust_file.write(line + '\n')
1725 # Cleanup
1726 if clean == True:
1727 for file_name in os.listdir(output_dir):
1728 file_path = os.path.join(output_dir, file_name)
1729 db_prefix = os.path.join(output_dir, prefix)
1730 if (
1731 file_path.startswith(f"{db_prefix}seqDB") or
1732 file_path.startswith(f"{db_prefix}alnDB")
1733 ):
1734 logging.info(f"Cleaning up file: {file_path}")
1735 os.remove(file_path)
1737 logging.info(f"IMPORTANT!\n{PPANGGOLIN_NOTICE}")
1739 return (defrag_clusters_path, defrag_rep_path)
1742def summarize(
1743 clusters: str,
1744 regions: str,
1745 outdir: str = ".",
1746 product_clean: dict = {
1747 " ": "_",
1748 "putative" : "",
1749 "hypothetical": "",
1750 "(": "",
1751 ")": "",
1752 ",": "",
1753 },
1754 max_product_len: int = 50,
1755 min_samples: int = 1,
1756 threshold: float = 0.5,
1757 prefix: str = None,
1758 args: str = None,
1759 ):
1760 """
1761 Summarize clusters according to their annotations.
1763 Takes as input the clusters TSV from either the cluster or defrag subcommand,
1764 and the TSV table of annotations from the collect subcommand.
1765 Outputs a TSV table of clusters and their summarized annotations.
1767 >>> summarize(clusters="clusters.tsv", sequences="sequences.tsv", prefix="summarize")
1769 :param clusters: TSV file of clusters from cluster or defrag (mmseqs).
1770 :param sequences: TSV file of sequences from collect.
1771 :param outdir: Output directory.
1772 :param product: Remove these words from the product description when it is the identifier.
1773 :param max_product_len: Truncate product to this length if it's being used as an identifier.
1774 :param prefix: Prefix for output files.
1775 :param args: Str of additional arguments for mmseqs search.
1777 :return: Ordered Dictionary of summarized clusters
1778 """
1780 import networkx
1781 from networkx.exception import NetworkXNoCycle
1783 sequences_path, clusters_path, output_dir = regions, clusters, outdir
1784 prefix = f"{prefix}." if prefix != None else ""
1785 check_output_dir(output_dir)
1787 args = args if args != None else ""
1789 # Type conversions as fallback
1790 threshold = float(threshold)
1791 max_product_len, min_samples = int(max_product_len), int(min_samples)
1793 # -------------------------------------------------------------------------
1794 # Read Sequence records
1796 logging.info(f"Reading sequence regions: {sequences_path}")
1797 sequences = OrderedDict()
1798 all_samples = []
1799 with open(sequences_path) as infile:
1800 header = [line.strip() for line in infile.readline().split("\t")]
1801 lines = infile.readlines()
1802 for line in tqdm(lines):
1803 row = [l.strip() for l in line.split("\t")]
1804 data = {k:v for k,v in zip(header, row)}
1805 sample, sequence_id = data["sample"], data["sequence_id"]
1806 if sample not in all_samples:
1807 all_samples.append(sample)
1808 if sequence_id not in sequences:
1809 sequences[sequence_id] = data
1810 # At this point, we should not have duplicate sequence IDs
1811 # That would have been handled in the extract/collect command
1812 else:
1813 msg = f"Duplicate sequence ID found: {sequence_id}"
1814 logging.error(msg)
1815 raise Exception(msg)
1817 # -------------------------------------------------------------------------
1818 # Read Clusters
1820 logging.info(f"Reading clusters: {clusters_path}")
1821 seen = set()
1822 clusters = OrderedDict()
1823 representative_to_cluster = {}
1824 i = 0
1826 with open(clusters_path) as infile:
1827 lines = infile.readlines()
1828 for line in tqdm(lines):
1830 # Extract the 3 columns: cluster representative, sequence Id, and (optional) fragment
1831 row = [l.strip() for l in line.split("\t")]
1832 representative, sequence_id = row[0], row[1]
1833 fragment = True if len(row) > 2 and row[2] == "F" else False
1834 if representative not in seen:
1835 i += 1
1836 cluster = f"Cluster_{i}"
1837 seen.add(representative)
1838 representative_to_cluster[representative] = cluster
1840 if sequence_id not in sequences:
1841 msg = f"Sequence is present in clusters but not in regions: {sequence_id}"
1842 logging.error(msg)
1843 raise Exception(msg)
1844 sequences[sequence_id]["cluster"] = cluster
1846 # Start the cluster data
1847 if cluster not in clusters:
1848 clusters[cluster] = OrderedDict({
1849 "cluster": cluster,
1850 "representative": representative,
1851 "sequences": OrderedDict()
1852 })
1853 if sequence_id not in clusters[cluster]["sequences"]:
1854 clusters[cluster]["sequences"][sequence_id] = sequences[sequence_id]
1855 clusters[cluster]["sequences"][sequence_id]["fragment"] = fragment
1856 else:
1857 msg = f"Duplicate cluster sequence ID found: {sequence_id}"
1858 logging.error(msg)
1859 raise Exception(msg)
1861 logging.info(f"Found {len(clusters)} clusters.")
1863 # -------------------------------------------------------------------------
1864 # Summarize
1865 # -------------------------------------------------------------------------
1867 logging.info(f"Summarizing clusters.")
1868 summarized = OrderedDict()
1869 gene_counts = {}
1870 product_counts = {}
1872 for cluster,cluster_data in tqdm(clusters.items()):
1874 # Identify number of samples
1875 cluster_samples = set()
1876 for s in cluster_data["sequences"].values():
1877 cluster_samples.add(s["sample"])
1878 num_samples = len(list(cluster_samples))
1879 # If user requested a minimum number of samples
1880 if num_samples < min_samples:
1881 continue
1882 samples_non_fragmented = set()
1883 sequences_per_sample = {}
1884 summarized[cluster] = OrderedDict({
1885 "cluster": cluster,
1886 "cluster_id": cluster,
1887 "synteny": "",
1888 "synteny_pos": "",
1889 "num_samples": num_samples,
1890 "num_samples_non_fragmented": 0,
1891 "num_sequences": 0,
1892 "mean_sequences_per_sample": 0,
1893 "representative": cluster_data["representative"],
1894 "upstream": "",
1895 "downstream": "",
1896 "upstream_alt": "",
1897 "downstream_alt": ""
1898 }
1899 )
1900 cluster_sequences = cluster_data["sequences"]
1901 features = Counter()
1902 genes = Counter()
1903 products = Counter()
1904 names = Counter()
1905 dbxrefs = Counter()
1906 strands = Counter()
1907 contexts = Counter()
1908 contexts_uniq = Counter()
1910 # -------------------------------------------------------------------------
1911 # Collect sequence attributes
1913 for sequence_id,seq_data in cluster_sequences.items():
1915 summarized[cluster]["num_sequences"] += 1
1917 sample = seq_data["sample"]
1918 if sample not in sequences_per_sample:
1919 sequences_per_sample[sample] = 0
1920 sequences_per_sample[sample] += 1
1922 # Ignore annotations from fragments starting here
1923 if seq_data["fragment"] == True: continue
1925 samples_non_fragmented.add(sample)
1926 features[seq_data["feature"]] += 1
1927 strands[seq_data["strand"]] += 1
1929 # collect the annotation attributes
1930 attributes = {a.split("=")[0]:a.split("=")[1] for a in seq_data["attributes"].split(";") if "=" in a}
1931 if "gene" in attributes:
1932 genes[attributes["gene"]] += 1
1933 if "product" in attributes:
1934 products[attributes["product"]] += 1
1935 if "Name" in attributes:
1936 names[attributes["Name"]] += 1
1937 if "Dbxref" in attributes:
1938 dbxrefs[attributes["Dbxref"]] += 1
1940 # Get the upstream/downstream locus IDs
1941 upstream, downstream = seq_data["upstream"], seq_data["downstream"]
1943 # Simply upstream/downstream loci if they represent the start/end of a contig
1944 upstream = "TERMINAL" if upstream.endswith("_TERMINAL") and "__" not in upstream else upstream
1945 downstream = "TERMINAL" if downstream.endswith("_TERMINAL") and "__" not in downstream else downstream
1947 # Get the upstream/downstream cluster IDs
1948 # There is a possibility that the upstream/downstream sequences
1949 # didn't actually get classified into a cluster
1950 if upstream in sequences and "cluster" in sequences[upstream]:
1951 upstream = sequences[upstream]["cluster"]
1952 if downstream in sequences and "cluster" in sequences[downstream]:
1953 downstream = sequences[downstream]["cluster"]
1955 # In what cases would it be itself?
1956 if upstream != cluster and downstream != cluster:
1957 context = [upstream, downstream]
1958 contexts["__".join(context)] += 1
1959 contexts_uniq["__".join(sorted(context))] += 1
1962 num_samples_non_fragmented = len(samples_non_fragmented)
1963 summarized[cluster]["num_samples_non_fragmented"] = num_samples_non_fragmented
1965 mean_sequences_per_sample = sum(sequences_per_sample.values()) / len(sequences_per_sample)
1966 summarized[cluster]["mean_sequences_per_sample"] = round(mean_sequences_per_sample, 1)
1968 # -------------------------------------------------------------------------
1969 # Summarize upstream/downstream
1971 # Part 1. Are there neighboring loci (regardless of directionality)?
1972 neighbors = None
1973 if len(contexts_uniq) > 1:
1974 most_common, count = contexts_uniq.most_common(1)[0]
1975 # Are these loci observed frequently enough?
1976 prop = count / num_samples_non_fragmented
1977 if prop >= threshold:
1978 # Handle ties, by simply picking the first option alphabetically
1979 candidates = sorted([c for c,v in contexts_uniq.items() if v == count])
1980 neighbors = candidates[0]
1981 if len(candidates) > 1:
1982 logging.debug(f"{cluster} tie broken alphabetically: neighbors={neighbors}, {contexts_uniq}")
1984 # Part 2. Filter the neighbors to our top matches, now allowing either direction
1985 # We will summarize in the next step
1986 if neighbors != None:
1987 c1, c2 = most_common.split("__")
1988 forward, reverse = f"{c1}__{c2}", f"{c2}__{c1}"
1989 contexts = Counter({k:v for k,v in contexts.items() if k == forward or k == reverse})
1992 # -------------------------------------------------------------------------
1993 # Summarize sequence attributes
1995 for key,values in zip(["feature", "strand", "gene", "product", "name", "dbxref", "contexts"], [features, strands, genes, products, names, dbxrefs, contexts]):
1996 value = ""
1997 value_alt = ""
1999 if len(values) > 0:
2000 # Check if the most common value passes the threshold
2001 most_common_count = values.most_common(1)[0][1]
2002 most_common_prop = most_common_count / num_samples_non_fragmented
2003 # We don't do a second threshold filter for contexts
2004 if key == "contexts" or most_common_prop >= threshold :
2005 # Handle ties, by simply picking the first option alphabetically
2006 candidates = sorted([c for c,v in values.items() if v == most_common_count])
2007 value = candidates[0]
2008 if len(candidates) > 1:
2009 logging.debug(f"{cluster} tie broken alphabetically: {key}={value}, {values}")
2010 value_alt = ";".join([v for v in values.keys() if v != value])
2011 else:
2012 value_alt = ";".join([v for v in values.keys()])
2014 if key == "contexts":
2015 if value != "":
2016 upstream, downstream = value.split("__")
2017 summarized[cluster]["upstream"] = upstream
2018 summarized[cluster]["downstream"] = downstream
2019 continue
2021 summarized[cluster][key] = value
2022 summarized[cluster][f"{key}_alt"] = value_alt
2024 # gene/product identifiers need to be checked for case!
2026 if key == "gene" and value != "":
2027 value_lower = value.lower()
2028 if value_lower not in gene_counts:
2029 gene_counts[value_lower] = {"count": 0, "current": 1}
2030 gene_counts[value_lower]["count"] += 1
2032 if key == "product":
2033 summarized[cluster]["product_clean"] = ""
2034 if value != "":
2035 # Clean up product as a potential identifier
2036 # These values are definitely not allowed
2037 clean_value = value.replace(" ", "_").replace("/", "_").replace(",", "_")
2038 for k,v in product_clean.items():
2039 clean_value = clean_value.replace(k, v)
2040 # Restrict the length
2041 if max_product_len != None:
2042 clean_value = clean_value[:max_product_len]
2043 while "__" in clean_value:
2044 clean_value = clean_value.replace("__", "_")
2045 while clean_value.startswith("_") or clean_value.endswith("_"):
2046 clean_value = clean_value.lstrip("_").rstrip("_")
2047 if clean_value != "":
2048 clean_value_lower = clean_value.lower()
2049 if clean_value_lower not in product_counts:
2050 product_counts[clean_value_lower] = {"count": 0, "current": 1}
2051 product_counts[clean_value_lower]["count"] += 1
2052 summarized[cluster]["product_clean"] = clean_value
2054 # Handle if feature was 'unannotated'
2055 # Can happen in tie breaking situations
2056 if (
2057 (key == "gene" or key == "product") and
2058 value != "" and
2059 summarized[cluster]["feature"] == "unannotated"
2060 ):
2061 candidates = [f for f in summarized[cluster]["feature_alt"].split(";") if f != "unannotated"]
2062 feature = candidates[0]
2063 summarized[cluster]["feature"] = feature
2064 features_alt = ["unannotated"] + [c for c in candidates if c != feature]
2065 summarized[cluster]["feature_alt"] = ";".join(features_alt)
2066 logging.debug(f"Updating {cluster} from an unannotated feature to {feature} based on {key}={value}.")
2068 summarized[cluster]["sequences"] = cluster_sequences
2070 # -------------------------------------------------------------------------
2071 # Identifiers: Part 1
2072 # -------------------------------------------------------------------------
2074 # Give identifiers to clusters based on their gene or product.
2075 # We do this now, so that the synteny graph has some nice
2076 # helpful names for clusters. We will give names to the
2077 # unnanotated clusters based on their upstream/downstream
2078 # loci after the synteny graph is made.
2080 logging.info(f"Assigning identifiers to annotated clusters.")
2082 identifiers = OrderedDict()
2084 # Pass 1. Identifiers based on gene/product
2085 for cluster,cluster_data in tqdm(summarized.items()):
2086 identifier = None
2087 # Option 1. Try to use gene as cluster identifier
2088 if cluster_data["gene"] != "":
2089 gene = cluster_data["gene"]
2090 # we use the lowercase to figure out the duplicate
2091 # number, but use the original name in the table
2092 gene_lower = gene.lower()
2093 if gene_counts[gene_lower]["count"] == 1:
2094 identifier = gene
2095 else:
2096 dup_i = gene_counts[gene_lower]["current"]
2097 gene_counts[gene_lower]["current"] += 1
2098 new_gene = f"{gene}.{dup_i}"
2099 logging.debug(f"Duplicate gene identifer found for {cluster}: {new_gene}")
2100 identifier = new_gene
2102 # Option 2. Try to use product as cluster identifier
2103 # Useful for non-gene annotations (ex. tRNA)
2104 elif cluster_data["product_clean"] != "":
2105 product = cluster_data["product_clean"]
2106 product_lower = product.lower()
2107 if product_counts[product_lower]["count"] == 1:
2108 identifier = product
2109 else:
2110 dup_i = product_counts[product_lower]["current"]
2111 product_counts[product_lower]["current"] += 1
2112 new_product = f"{product}.{dup_i}"
2113 logging.debug(f"Duplicate product identifer found for {cluster}: {new_product}")
2114 identifier = new_product
2116 if identifier != None:
2117 identifiers[cluster] = identifier
2118 summarized[cluster]["cluster"] = identifier
2120 # Pass 2: Update the upstream/downstream identifiers
2121 for cluster,cluster_data in tqdm(summarized.items()):
2122 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"]
2123 if upstream in identifiers:
2124 summarized[cluster]["upstream"] = identifiers[upstream]
2125 if downstream in identifiers:
2126 summarized[cluster]["downstream"] = identifiers[downstream]
2128 # Update the cluster keys
2129 summarized = OrderedDict({v["cluster"]:v for k,v in summarized.items()})
2131 # -------------------------------------------------------------------------
2132 # Synteny
2133 # -------------------------------------------------------------------------
2135 logging.info(f"Computing initial synteny graph.")
2137 # -------------------------------------------------------------------------
2138 # Create initial graph
2140 # Create simple graph based on what we know about the
2141 # upstream/downstream loci so far. This will be our "full"
2142 # graph which retains all the cycles and multifurcations
2143 synteny_full = networkx.DiGraph()
2144 seen_nodes = set()
2146 for cluster, cluster_data in summarized.items():
2147 if cluster not in seen_nodes:
2148 synteny_full.add_node(cluster)
2149 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"]
2150 if upstream != "" and upstream != "TERMINAL":
2151 synteny_full.add_edge(upstream, cluster)
2152 if downstream != "" and downstream != "TERMINAL":
2153 synteny_full.add_edge(cluster, downstream)
2155 # -------------------------------------------------------------------------
2156 # Filter nodes
2158 # Remove clusters from the graph. This can happen if the
2159 # user requested min_samples>0, so we need to remove the
2160 # low prevalence clusters. It can also happen if a
2161 # cluster had an upstream/downstream loci that didn't
2162 # make it to the final clustering list
2163 logging.info(f"Filtering graph for missing clusters.")
2164 for node in list(networkx.dfs_tree(synteny_full)):
2165 if node in summarized: continue
2166 in_nodes = [e[0] for e in synteny_full.in_edges(node)]
2167 out_nodes = [e[1] for e in synteny_full.out_edges(node)]
2168 # Remove this node from the graph
2169 logging.debug(f"Removing {node} from the graph.")
2170 synteny_full.remove_node(node)
2171 # Connect new edges between in -> out
2172 for n1 in [n for n in in_nodes if n not in out_nodes]:
2173 for n2 in [n for n in out_nodes if n not in in_nodes]:
2174 # Not sure if we need this check?
2175 if not synteny_full.has_edge(n1, n2):
2176 synteny_full.add_edge(n1, n2)
2178 # -------------------------------------------------------------------------
2179 # Break up multifurcations
2181 synteny_linear = copy.deepcopy(synteny_full)
2183 logging.info(f"Breaking up multifurcations.")
2184 for node in tqdm(list(networkx.dfs_tree(synteny_linear))):
2185 in_nodes = [e[0] for e in synteny_linear.in_edges(node)]
2186 out_nodes = [e[1] for e in synteny_linear.out_edges(node)]
2188 if len(in_nodes) > 1:
2189 for n in in_nodes:
2190 logging.debug(f"Removing multifurcation in_edge: {n} -> {node}")
2191 synteny_linear.remove_edge(n, node)
2192 if len(out_nodes) > 1:
2193 for n in out_nodes:
2194 logging.debug(f"Removing multifurcation out_edge: {node} -> {n}")
2195 synteny_linear.remove_edge(node, n)
2197 # -------------------------------------------------------------------------
2198 # Isolate and linear cycles
2200 # I'm not sure in what cases this is still needed
2201 # after switching the graph to directed
2202 logging.info(f"Isolating cycles and linearizing.")
2203 cycles = True
2204 while cycles == True:
2205 try:
2206 cycle_raw = [x for xs in networkx.find_cycle(synteny_linear) for x in xs]
2207 seen = set()
2208 cycle = []
2209 for n in cycle_raw:
2210 if n not in seen:
2211 cycle.append(n)
2212 seen.add(n)
2213 logging.debug(f"Cycle found: {cycle}")
2214 cycle_set = set(cycle)
2215 # Separate this cycle from the rest of the graph
2216 for n1 in synteny_linear.nodes():
2217 if n1 in cycle: continue
2218 neighbors = set(synteny_linear.neighbors(n1))
2219 for n2 in cycle_set.intersection(neighbors):
2220 logging.debug(f"Isolating cycle by removing edge: {n1} <-> {n2}")
2221 synteny_linear.remove_edge(n1, n2)
2223 # Break the final cycle between the 'first' and 'last' nodes
2224 first, last = cycle[0], cycle[-1]
2225 logging.debug(f"Breaking cycle by removing edge: {last} -> {first}")
2226 synteny_linear.remove_edge(last, first)
2228 except NetworkXNoCycle:
2229 cycles = False
2231 # -------------------------------------------------------------------------
2232 # Identify synteny blocks
2234 # We need to use the function connected_components, but that only
2235 # works on undirected graphs
2236 synteny_linear_undirected = synteny_linear.to_undirected()
2238 # Get synteny blocks, sorted largest to smallest
2239 logging.info(f"Identifying synteny blocks.")
2240 synteny_blocks = sorted([
2241 synteny_linear_undirected.subgraph(c).copy()
2242 for c in networkx.connected_components(synteny_linear_undirected)
2243 ], key=len, reverse=True)
2245 # Now we have enough information to finalize the order
2246 summarized_order = OrderedDict()
2248 for i_b, block in enumerate(tqdm(synteny_blocks)):
2249 i_b += 1
2250 clusters = list(block)
2252 if len(clusters) > 1:
2253 terminals = [n for n in block.nodes if len(list(block.neighbors(n))) == 1]
2254 # Legacy error from troubleshooting, unclear if still relevant
2255 if len(terminals) != 2:
2256 # Check if it's a cycle we somehow didn't handle
2257 try:
2258 cycle = networkx.find_cycle(block)
2259 cycle = True
2260 except NetworkXNoCycle:
2261 cycle = False
2262 msg = f"Synteny block {i_b} has an unhandled error, cycle={cycle}, terminals={len(terminals)}: {terminals}"
2263 logging.error(msg)
2264 print(networkx.write_network_text(block))
2265 raise Exception(msg)
2267 # Figure out which terminal is the 5' end
2268 first_upstream = summarized[terminals[0]]["upstream"]
2269 last_upstream = summarized[terminals[-1]]["upstream"]
2271 if first_upstream == "TERMINAL":
2272 first, last = terminals[0], terminals[1]
2273 elif last_upstream == "TERMINAL":
2274 first, last = terminals[1], terminals[0]
2275 else:
2276 # If it's fully ambiguous, we'll sort the terminals for reproducibility
2277 terminals = sorted(terminals)
2278 first, last = terminals[0], terminals[1]
2280 # Manually walk through the graph, this is not ideal
2281 # but networkx's bfs/dfs was giving me odd values in testing
2282 clusters, neighbors = [first], []
2283 curr_node, next_node = first, None
2284 while next_node != last:
2285 neighbors = [n for n in block.neighbors(curr_node) if n not in clusters]
2286 # Legacy error from troubleshooting, unclear if still relevant
2287 if len(neighbors) != 1:
2288 msg = f"Synteny error, unhandled multifurcation in {curr_node}: {neighbors}"
2289 logging.error(msg)
2290 raise Exception(msg)
2291 next_node = neighbors[0]
2292 clusters.append(next_node)
2293 curr_node = next_node
2295 for i_c, cluster in enumerate(clusters):
2296 summarized_order[cluster] = summarized[cluster]
2297 summarized_order[cluster]["synteny"] = str(i_b)
2298 summarized_order[cluster]["synteny_pos"] = str(i_c + 1)
2300 upstream = summarized_order[cluster]["upstream"]
2301 downstream = summarized_order[cluster]["downstream"]
2303 # Use the synteny block to finalize upstream/downstream
2304 upstream = summarized_order[cluster]["upstream"]
2305 downstream = summarized_order[cluster]["downstream"]
2307 # We'll save a copy of how it was before as the alt
2308 upstream_orig, downstream_orig = copy.deepcopy(upstream), copy.deepcopy(downstream)
2310 # 'TERMINAL' will now refer to the ends of the synteny block
2311 if i_c > 0:
2312 upstream = clusters[i_c - 1]
2313 else:
2314 upstream = "TERMINAL"
2316 if i_c < (len(clusters) - 1):
2317 downstream = clusters[i_c + 1]
2318 else:
2319 downstream = "TERMINAL"
2321 if upstream != upstream_orig:
2322 summarized_order[cluster]["upstream_alt"] = upstream_orig
2323 if downstream != downstream_orig:
2324 summarized_order[cluster]["downstream_alt"] = downstream_orig
2326 summarized_order[cluster]["upstream"] = upstream
2327 summarized_order[cluster]["downstream"] = downstream
2329 summarized = summarized_order
2331 # -------------------------------------------------------------------------
2332 # Create directed
2334 logging.info(f"Converting synteny blocks to directed graph.")
2336 # Now we need to go back to a directed graph
2337 synteny_linear_directed = networkx.DiGraph()
2338 synteny_seen = set()
2340 for cluster,cluster_data in summarized.items():
2341 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"]
2342 if cluster not in synteny_seen:
2343 synteny_seen.add(cluster)
2344 synteny_linear_directed.add_node(cluster)
2346 if upstream != "TERMINAL":
2347 if upstream not in synteny_seen:
2348 synteny_seen.add(upstream)
2349 synteny_linear_directed.add_edge(upstream, cluster)
2351 if downstream != "TERMINAL":
2352 if downstream not in synteny_seen:
2353 synteny_seen.add(downstream)
2354 synteny_linear_directed.add_edge(cluster, downstream)
2356 # -------------------------------------------------------------------------
2357 # Identifiers: Part 2
2358 # -------------------------------------------------------------------------
2360 # Give unannotated clusters identifiers based on their upstream/downstream
2361 # loci in the synteny graph.
2363 # Note: the synteny reconstruction ensures that we're not going to have
2364 # any duplicate cluster IDs for the unannotated regions
2365 # because any loops/multifurcations have been removed.
2367 logging.info(f"Assigning identifiers to unannotated clusters.")
2369 # -------------------------------------------------------------------------
2370 # Pass #1: Give identifiers to unannotated clusters based on upstream/downstream
2372 for cluster,cluster_data in tqdm(summarized.items()):
2373 # Skip this cluster if it already has a new identifier
2374 # Example, annotated clusters based on gene/product
2375 if cluster in identifiers and identifiers[cluster] != cluster:
2376 continue
2377 # Skip this cluster if it has gene/product info, just a safety fallback
2378 if cluster_data["gene"] != "" or cluster_data["product"] != "":
2379 continue
2380 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"]
2381 # Option 1. No known neighbors
2382 if upstream == "TERMINAL" and downstream == "TERMINAL":
2383 identifier = cluster
2384 # Option 2. Upstream/downstream already has the notation
2385 elif "__" in upstream or "__" in downstream:
2386 identifier = cluster
2387 # Option 3. At least one side is known
2388 else:
2389 identifier = f"{upstream}__{downstream}"
2391 identifiers[cluster] = identifier
2392 summarized[cluster]["cluster"] = identifier
2394 # -------------------------------------------------------------------------
2395 # Pass #2: Finalize upstream/downstream
2397 # Now that we know the identifiers, we need to update the
2398 # following fields: cluster, upstream, and downstream
2400 logging.info(f"Finalizing upstream/downstream identifiers.")
2402 for cluster, cluster_data in tqdm(summarized.items()):
2403 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"]
2404 if upstream in identifiers:
2405 summarized[cluster]["upstream"] = upstream
2406 if downstream in identifiers:
2407 summarized[cluster]["downstream"] = downstream
2409 # Update the keys in the graph
2410 summarized = OrderedDict({v["cluster"]:v for k,v in summarized.items()})
2412 # -------------------------------------------------------------------------
2413 # Update synteny graph with new identifiers
2414 # -------------------------------------------------------------------------
2416 logging.info(f"Updating cluster identifiers in the synteny graphs.")
2418 # We will both update the original "full" graph, as well as our new
2419 # "linear" graph
2421 # -------------------------------------------------------------------------
2422 # Full Graph
2424 networkx.relabel_nodes(synteny_full, mapping=identifiers)
2425 edges = list(synteny_full.out_edges())
2426 for c1, c2 in edges:
2427 synteny_full.remove_edge(c1, c2)
2428 c1 = identifiers[c1] if c1 in identifiers else c1
2429 c2 = identifiers[c2] if c2 in identifiers else c2
2430 synteny_full.add_edge(c1, c2)
2432 synteny_full_path = os.path.join(output_dir, f"{prefix}synteny.full.graphml")
2433 logging.info(f"Writing full synteny GraphML: {synteny_full_path}")
2434 networkx.write_graphml(synteny_full, synteny_full_path)
2436 gfa_path = os.path.join(output_dir, f"{prefix}synteny.full.gfa")
2437 logging.info(f"Writing full synteny GFA: {gfa_path}")
2439 with open(gfa_path, 'w') as outfile:
2440 outfile.write("H\tVN:Z:1.0\n")
2441 for cluster in summarized:
2442 outfile.write(f"S\t{cluster}\t*\tLN:i:1\n")
2443 for c1, c2 in synteny_full.out_edges():
2444 c1_strand, c2_strand = summarized[c1]["strand"], summarized[c2]["strand"]
2445 outfile.write(f"L\t{c1}\t{c1_strand}\t{c2}\t{c2_strand}\t0M\n")
2447 # -------------------------------------------------------------------------
2448 # Linearized Graph
2450 networkx.relabel_nodes(synteny_linear_directed, mapping=identifiers)
2451 edges = list(synteny_linear_directed.out_edges())
2453 for c1, c2 in edges:
2454 synteny_linear_directed.remove_edge(c1, c2)
2455 c1 = identifiers[c1] if c1 in identifiers else c1
2456 c2 = identifiers[c2] if c2 in identifiers else c2
2457 synteny_linear_directed.add_edge(c1, c2)
2459 synteny_linear_path = os.path.join(output_dir, f"{prefix}synteny.linear.graphml")
2460 logging.info(f"Writing linear synteny GraphML: {synteny_linear_path}")
2461 networkx.write_graphml(synteny_linear_directed, synteny_linear_path)
2463 gfa_path = os.path.join(output_dir, f"{prefix}synteny.linear.gfa")
2464 logging.info(f"Writing linear synteny GFA: {gfa_path}")
2466 with open(gfa_path, 'w') as outfile:
2467 outfile.write("H\tVN:Z:1.0\n")
2468 for cluster in summarized:
2469 outfile.write(f"S\t{cluster}\t*\tLN:i:1\n")
2470 for c1, c2 in synteny_linear_directed.out_edges():
2471 c1_strand, c2_strand = summarized[c1]["strand"], summarized[c2]["strand"]
2472 outfile.write(f"L\t{c1}\t{c1_strand}\t{c2}\t{c2_strand}\t0M\n")
2474 # -------------------------------------------------------------------------
2475 # Write Output tsv
2477 tsv_path = os.path.join(output_dir, f"{prefix}clusters.tsv")
2478 logging.info(f"Writing summarized clusters tsv: {tsv_path}")
2480 with open(tsv_path, 'w') as outfile:
2481 header = None
2482 for i, cluster in enumerate(tqdm(summarized)):
2483 cluster_data = summarized[cluster]
2484 if not header:
2485 header = [k for k in cluster_data if k != "sequences" and k != "product_clean"] + all_samples
2486 outfile.write("\t".join(header) + "\n")
2487 row = [str(v) for k,v in cluster_data.items() if k != "sequences" and k != "product_clean"]
2488 # Add info about which sequences map to each sample in the cluster
2489 sample_to_seq_id = OrderedDict({sample:[] for sample in all_samples})
2490 for seq_id,seq_data in cluster_data["sequences"].items():
2491 sample = seq_data["sample"]
2492 sample_to_seq_id[sample].append(seq_id)
2493 for sample in all_samples:
2494 row += [",".join(sample_to_seq_id[sample])]
2496 outfile.write("\t".join(row) + "\n")
2498 # -------------------------------------------------------------------------
2499 # Write table (for phandango)
2501 phandango_path = os.path.join(output_dir, f"{prefix}phandango.csv")
2502 logging.info(f"Writing table for phandango: {phandango_path}")
2504 logging.info(f"Sorting synteny blocks according to number of samples: {phandango_path}")
2505 syntenies = OrderedDict()
2506 for cluster,c_data in summarized.items():
2507 synteny = c_data["synteny"]
2508 if synteny not in syntenies:
2509 syntenies[synteny] = {"max_samples": 0, "clusters": []}
2510 num_samples = c_data["num_samples"]
2511 syntenies[synteny]["max_samples"] = max(num_samples, syntenies[synteny]["max_samples"])
2512 syntenies[synteny]["clusters"].append(cluster)
2514 syntenies = OrderedDict(
2515 sorted(
2516 syntenies.items(),
2517 key=lambda item: item[1]["max_samples"], reverse=True
2518 )
2519 )
2521 # This it the roary gene_presence_absence.csv format
2522 with open(phandango_path, 'w') as outfile:
2523 header = [
2524 "Gene","Non-unique Gene name","Annotation","No. isolates",
2525 "No. sequences","Avg sequences per isolate","Genome Fragment",
2526 "Order within Fragment","Accessory Fragment",
2527 "Accessory Order with Fragment","QC"
2528 ] + all_samples
2529 outfile.write(",".join(header) + "\n")
2530 for synteny,s_data in tqdm(syntenies.items()):
2531 for cluster in s_data["clusters"]:
2532 c_data = summarized[cluster]
2533 data = OrderedDict({k:"" for k in header})
2534 # This is deliberately reversed
2535 data["Gene"] = c_data["cluster"]
2536 data["Non-unique Gene name"] = c_data["cluster_id"]
2537 # We're going to use the cleaned product, because
2538 # we need at least some bad characters to removed
2539 # (like commas)
2540 data["Annotation"] = c_data["product_clean"]
2541 data["No. isolates"] = c_data["num_samples"]
2542 data["No. sequences"] = cluster_data["num_sequences"]
2543 data["Avg sequences per isolate"] = c_data["mean_sequences_per_sample"]
2544 data["Genome Fragment"] = c_data["synteny"]
2545 data["Order within Fragment"] = c_data["synteny_pos"]
2546 # If a sample has a sequence then it will be given a "1"
2547 # If not, is recorded as the empty string ""
2548 # This is based on the file minimizing back from phandango:
2549 # https://github.com/jameshadfield/phandango/blob/master/scripts/minimiseROARY.py
2550 for s_data in c_data["sequences"].values():
2551 sample = s_data["sample"]
2552 data[sample] = "1"
2553 line = ",".join([str(v) for v in data.values()])
2554 outfile.write(line + "\n")
2556 return tsv_path
2559# Parallel
2560def run_mafft(kwargs: dict):
2561 """A wrapper function to pass multi-threaded pool args to mafft."""
2562 cmd, output, quiet, display_cmd = [kwargs[k] for k in ["cmd", "output", "quiet", "display_cmd"]]
2563 run_cmd(cmd=cmd, output=output, quiet=quiet, display_cmd=display_cmd)
2565def align(
2566 clusters: str,
2567 regions: str,
2568 outdir: str = ".",
2569 prefix: str = None,
2570 exclude_singletons: bool = False,
2571 threads: int = 1,
2572 args: str = ALIGN_ARGS,
2573 ):
2574 """
2575 Align clusters using mafft and create a pangenome alignment.
2577 Takes as input the summarized clusters from summarize and sequence regions from collect.
2578 Outputs multiple sequence alignments per cluster as well as a pangenome alignment of
2579 concatenated clusters.
2581 >>> align(clusters="summarize.clusters.tsv", sequences="sequences.tsv")
2582 >>> align(clusters="summarize.clusters.tsv", sequences="sequences.tsv", exclude_singletons=True, args="--localpair")
2584 :param clusters: TSV file of clusters from summarize.
2585 :param sequences: TSV file of sequence regions from collect.
2586 :param outdir: Output directory.
2587 :param prefix: Prefix for output files.
2588 :param exclude_singletons: True is clusters found in only one sample should be excluded.
2589 :param threads: Number of cpu threads to parallelize mafft across.
2590 :param args: Additional arguments for MAFFT.
2591 """
2593 from multiprocessing import get_context
2595 clusters_path, sequences_path, output_dir = clusters, regions, outdir
2597 # Check output directory
2598 check_output_dir(output_dir)
2600 args = args if args != None else ""
2601 prefix = f"{prefix}." if prefix != None else ""
2602 threads = int(threads)
2604 # -------------------------------------------------------------------------
2605 # Read Sequence Regions
2607 all_samples = []
2608 sequences = {}
2610 logging.info(f"Reading sequence regions: {sequences_path}")
2611 with open(sequences_path) as infile:
2612 header = [line.strip() for line in infile.readline().split("\t")]
2613 for line in infile:
2614 row = [l.strip() for l in line.split("\t")]
2615 data = {k:v for k,v in zip(header, row)}
2616 sample, sequence_id = data["sample"], data["sequence_id"]
2617 if sample not in all_samples:
2618 all_samples.append(sample)
2619 sequences[sequence_id] = data
2621 # -------------------------------------------------------------------------
2622 # Read Summarized Clusters
2624 logging.info(f"Reading summarized clusters: {clusters_path}")
2625 clusters = OrderedDict()
2626 all_samples = []
2627 with open(clusters_path) as infile:
2628 header = [line.strip() for line in infile.readline().split("\t")]
2629 # sample columns begin after dbxref_alt
2630 all_samples = header[header.index("dbxref_alt")+1:]
2631 lines = infile.readlines()
2632 for line in tqdm(lines):
2633 row = [l.strip() for l in line.split("\t")]
2634 data = {k:v for k,v in zip(header, row)}
2635 cluster = data["cluster"]
2636 clusters[cluster] = data
2637 clusters[cluster]["sequences"] = OrderedDict()
2638 for sample in all_samples:
2639 sequence_ids = data[sample].split(",") if data[sample] != "" else []
2640 # Add the sequence IDs, we will associate it with the actual
2641 # sequence from the collect TSV
2642 for sequence_id in sequence_ids:
2643 clusters[cluster]["sequences"][sequence_id] = {
2644 "sequence": sequences[sequence_id]["sequence"],
2645 "sample": sample
2646 }
2648 # -------------------------------------------------------------------------
2649 # Write Sequences
2651 representative_output = os.path.join(output_dir, prefix + "representative")
2652 sequences_output = os.path.join(output_dir, prefix + "sequences")
2653 alignments_output = os.path.join(output_dir, prefix + "alignments")
2654 consensus_output = os.path.join(output_dir, prefix + "consensus")
2656 check_output_dir(representative_output)
2657 check_output_dir(sequences_output)
2658 check_output_dir(alignments_output)
2659 check_output_dir(consensus_output)
2661 # -------------------------------------------------------------------------
2662 # Write Representative sequences for each cluster to file
2664 skip_align = set()
2665 clusters_exclude_singletons = OrderedDict()
2667 logging.info(f"Writing representative sequences: {representative_output}")
2668 for cluster,cluster_data in tqdm(clusters.items()):
2669 rep_seq_id = cluster_data["representative"]
2670 rep_seq = cluster_data["sequences"][rep_seq_id]["sequence"]
2671 rep_path = os.path.join(representative_output, f"{cluster}.fasta")
2672 samples = list(set([s["sample"] for s in cluster_data["sequences"].values()]))
2674 # If only found in 1 sample, and the user requested to exclude these, skip
2675 if len(samples) == 1 and exclude_singletons:
2676 logging.debug(f"Skipping singleton cluster: {cluster}")
2677 continue
2679 # If this is cluster only has one sequence, write to final output
2680 if len(cluster_data["sequences"]) == 1:
2681 skip_align.add(cluster)
2682 file_path = os.path.join(alignments_output, f"{cluster}.aln")
2683 with open(file_path, 'w') as outfile:
2684 outfile.write(f">{rep_seq_id}\n{rep_seq}\n")
2685 # Otherwise, save rep seq to separate folder
2686 # We might need it if user has requeseted the --add* mafft args
2687 else:
2688 with open(rep_path, 'w') as outfile:
2689 outfile.write(f">{rep_seq_id}\n{rep_seq}\n")
2691 clusters_exclude_singletons[cluster] = cluster_data
2693 clusters = clusters_exclude_singletons
2695 # -------------------------------------------------------------------------
2696 # Write DNA sequences for each cluster to file
2698 # A queue of commands to submit to mafft in parallel
2699 mafft_queue = []
2701 logging.info(f"Writing cluster sequences: {sequences_output}")
2702 for i,cluster in enumerate(tqdm(clusters)):
2703 # Skip singleton clusters, we already handled them in previous block
2704 if cluster in skip_align: continue
2705 cluster_data = clusters[cluster]
2706 rep_seq_id = cluster_data["representative"]
2707 representative_path = os.path.join(representative_output, f"{cluster}.fasta")
2708 sequences_path = os.path.join(sequences_output, f"{cluster}.fasta")
2709 with open(sequences_path, "w") as outfile:
2710 for sequence_id,seq_data in cluster_data["sequences"].items():
2711 # Skip representative if we're using addfragments
2712 if "--add" in args and sequence_id == rep_seq_id:
2713 continue
2714 sequence = seq_data["sequence"]
2715 line = f">{sequence_id}\n{sequence}\n"
2716 outfile.write(line)
2717 alignment_path = os.path.join(alignments_output, f"{cluster}.aln")
2718 cmd = f"mafft --thread 1 {args} {sequences_path}"
2719 # If we're using addfragments, we're aligning against representative/reference
2720 if "--add" in args:
2721 cmd += f" {representative_path}"
2722 mafft_queue.append({"cmd": cmd, "output": alignment_path, "quiet": True, "display_cmd": False})
2724 # Display first command
2725 if len(mafft_queue) > 0:
2726 logging.info(f"Command to run in parallel: {mafft_queue[0]['cmd']}")
2728 # -------------------------------------------------------------------------
2729 # Align DNA sequences with MAFFT
2731 # Parallel: MAFFT is very CPU-bound, so parallel is appropriate
2732 logging.info(f"Aligning cluster sequences in parallel with {threads} threads: {alignments_output}")
2733 with get_context('fork').Pool(threads) as p:
2734 with tqdm(total=len(mafft_queue), unit="cluster") as bar:
2735 for _ in p.imap_unordered(run_mafft, mafft_queue):
2736 bar.update()
2738 # -------------------------------------------------------------------------
2739 # Unwrap and dedup cluster alignments
2741 # Duplicates occur due to multi-copy/fragments. We'll try to dedup the
2742 # sequence by reconstructing consensus bases where possible.
2744 logging.info(f"Unwrapping and dedupping alignments: {consensus_output}")
2745 consensus_alignments = OrderedDict()
2747 for cluster,cluster_data in tqdm(clusters.items()):
2748 original_path = os.path.join(alignments_output, cluster + ".aln")
2749 tmp_path = original_path + ".tmp"
2750 consensus_path = os.path.join(consensus_output, cluster + ".aln")
2752 # Unwrap sequences and convert to uppercase
2753 alignment = {}
2754 with open(original_path, 'r') as infile:
2755 with open(tmp_path, 'w') as outfile:
2756 records = infile.read().split(">")[1:]
2757 for record in records:
2758 record_split = record.split("\n")
2759 sequence_id = record_split[0]
2760 # Check for indicator that it was reverse complemented
2761 if sequence_id not in cluster_data["sequences"] and sequence_id.startswith("_R_"):
2762 sequence_id = sequence_id[3:]
2764 sample = cluster_data["sequences"][sequence_id]["sample"]
2765 sequence = "".join(record_split[1:]).replace("\n", "").upper()
2766 if sample not in alignment:
2767 alignment[sample] = []
2768 alignment[sample].append(sequence)
2770 # Write the sequence to the output file, using the original header
2771 outfile.write(f">{sequence_id}\n{sequence}\n")
2773 # Replace the original file, now that unwrapped and uppercased
2774 # We can use this in a different script for structural variant detection
2775 shutil.move(tmp_path, original_path)
2777 if len(alignment) == 0:
2778 logging.info(f"WARNING: No sequences written for cluster: {cluster}")
2779 continue
2781 # Create consensus sequence from fragments and multi-copies
2782 duplicate_samples = [sample for sample,seqs in alignment.items() if len(seqs) > 1]
2783 # Create consensus alignment, first with the non-duplicated samples
2784 alignment_consensus = {sample:seqs[0] for sample,seqs in alignment.items() if len(seqs) == 1}
2786 for sample in duplicate_samples:
2787 seqs = alignment[sample]
2788 length = len(seqs[0])
2789 consensus = []
2791 # Reconstruct Consensus sequence
2792 for i in range(0, length):
2793 nuc_raw = set([s[i] for s in seqs])
2794 nuc = list(set([n for n in nuc_raw if n in NUCLEOTIDES] ))
2795 if len(nuc) == 1: nuc_consensus = nuc[0]
2796 elif nuc_raw == set("-"): nuc_consensus = "-"
2797 else: nuc_consensus = "N"
2798 consensus.append(nuc_consensus)
2800 consensus_sequence = "".join(consensus)
2801 alignment_consensus[sample] = consensus_sequence
2803 consensus_alignments[cluster] = alignment_consensus
2805 # Write unwrapped, dedupped, defragged, consensus alignment
2806 with open(consensus_path, 'w') as outfile:
2807 for (sample, sequence) in alignment_consensus.items():
2808 outfile.write(f">{sample}\n{sequence}\n")
2810 # -------------------------------------------------------------------------
2811 # Concatenate into pangenome alignment
2813 logging.info(f"Creating pangenome alignment.")
2815 pangenome = {
2816 "bed" : {},
2817 "alignment" : {s: [] for s in all_samples},
2818 }
2820 curr_pos = 0
2822 for cluster,alignment in tqdm(consensus_alignments.items()):
2823 # identify samples missing from cluster
2824 observed_samples = list(alignment.keys())
2825 missing_samples = [s for s in all_samples if s not in alignment]
2827 # write gaps for missing samples
2828 seq_len = len(alignment[observed_samples[0]])
2829 for sample in missing_samples:
2830 alignment[sample] = "-" * seq_len
2832 # concatenate cluster sequence to phylo alignment
2833 for sample, seq in alignment.items():
2834 pangenome["alignment"][sample].append(seq)
2836 # update bed coordinates
2837 prev_pos = curr_pos
2838 curr_pos = curr_pos + seq_len
2840 pangenome["bed"][prev_pos] = {
2841 "start" : prev_pos,
2842 "end" : curr_pos,
2843 "cluster" : cluster,
2844 "synteny": clusters[cluster]["synteny"]
2845 }
2847 # Final concatenation of alignment
2848 logging.info("Performing final concatenation.")
2849 for sample in tqdm(all_samples):
2850 pangenome["alignment"][sample] = "".join( pangenome["alignment"][sample])
2852 # -------------------------------------------------------------------------
2853 # Write pangenome Consensus Sequence
2855 consensus_file_path = os.path.join(output_dir, prefix + "pangenome.consensus.fasta")
2856 logging.info(f"Writing pangenome consensus: {consensus_file_path}")
2857 with open(consensus_file_path, 'w') as out_file:
2858 out_file.write(">consensus\n")
2859 # get total length from first sample in alignment
2860 first_sample = list(pangenome["alignment"].keys())[0]
2861 consensus_len = len(pangenome["alignment"][first_sample])
2862 consensus = []
2863 logging.info(f"Consensus Length: {consensus_len}")
2864 for cluster_start,bed_info in tqdm(pangenome["bed"].items()):
2865 cluster_end, cluster = bed_info["end"], bed_info["cluster"]
2866 for i in range(cluster_start, cluster_end):
2867 nuc = [pangenome["alignment"][s][i] for s in all_samples]
2868 nuc_non_ambig = [n for n in nuc if n in NUCLEOTIDES]
2869 # All samples excluded due to ambiguity
2870 if len(set(nuc_non_ambig)) == 0:
2871 consensus_nuc = "N"
2872 # Invariant position, all the same
2873 elif len(set(nuc_non_ambig)) == 1:
2874 consensus_nuc = nuc[0]
2875 # Variant position, get consensus of nonambiguous nucleotides
2876 else:
2877 # choose which ever nucleotide has highest count as the consensus
2878 counts = {n:nuc.count(n) for n in NUCLEOTIDES}
2879 consensus_nuc = max(counts, key=counts.get)
2880 consensus.append(consensus_nuc)
2881 # write final end of line
2882 out_file.write("".join(consensus) + "\n")
2884 # -------------------------------------------------------------------------
2885 # Write Bed File
2887 bed_file_path = os.path.join(output_dir, prefix + "pangenome.bed")
2888 logging.info(f"Writing bed file: {bed_file_path}")
2889 with open(bed_file_path, "w") as bed_file:
2890 for start,info in tqdm(pangenome["bed"].items()):
2891 cluster, synteny = info["cluster"], info["synteny"]
2892 name = f"cluster={cluster};synteny={synteny}"
2893 row = ["pangenome", start, info["end"], name]
2894 line = "\t".join([str(val) for val in row])
2895 bed_file.write(line + "\n")
2897 # -------------------------------------------------------------------------
2898 # Write Alignment
2900 aln_file_path = os.path.join(output_dir, prefix + "pangenome.aln")
2901 logging.info(f"Writing alignment file: {aln_file_path}")
2902 with open(aln_file_path, "w") as aln_file:
2903 for sample in tqdm(all_samples):
2904 seq = pangenome["alignment"][sample].upper()
2905 aln_file.write(">" + sample + "\n" + seq + "\n")
2907 # -------------------------------------------------------------------------
2908 # Write GFF
2910 gff_file_path = os.path.join(output_dir, prefix + "pangenome.gff3")
2911 logging.info(f"Writing gff file: {gff_file_path}")
2912 with open(gff_file_path, "w") as outfile:
2913 outfile.write("##gff-version 3\n")
2914 outfile.write(f"##sequence-region pangenome 1 {consensus_len}\n")
2915 row = ["pangenome", ".", "region", 1, consensus_len, ".", "+", ".", "ID=pangenome;Name=pangenome"]
2916 line = "\t".join([str(v) for v in row])
2917 outfile.write(line + "\n")
2919 for start in pangenome["bed"]:
2920 start, end, cluster, synteny = pangenome["bed"][start].values()
2921 cluster_data = clusters[cluster]
2922 representative, feature = cluster_data["representative"], cluster_data["feature"]
2923 strand = sequences[representative]["strand"]
2924 attributes = OrderedDict({"ID": cluster, "locus_tag": cluster})
2925 for col,k in zip(["gene", "name", "product", "dbxref"], ["gene", "Name", "product", "Dbxref"]):
2926 if cluster_data[col] != "":
2927 attributes[k] = cluster_data[col]
2928 attributes[synteny] = synteny
2929 attributes = ";".join([f"{k}={v}" for k,v in attributes.items()])
2930 row = ["pangenome", ".", feature, start, end, ".", strand, ".", attributes ]
2931 line = "\t".join([str(v) for v in row])
2932 outfile.write(line + "\n")
2935def get_ranges(numbers):
2936 """
2937 Author: user97370, bossylobster
2938 Source: https://stackoverflow.com/a/4629241
2939 """
2940 for _a, b in itertools.groupby(enumerate(numbers), lambda pair: pair[1] - pair[0]):
2941 b = list(b)
2942 yield b[0][1], b[-1][1]
2945def collapse_ranges(ranges, min_len: int, min_gap: int):
2946 """Collapse a list of ranges into overlaps."""
2947 collapse = []
2948 curr_range = None
2949 # Collapse based on min gap
2950 for i, coord in enumerate(ranges):
2951 start, stop = coord
2952 if curr_range:
2953 gap_size = start - curr_range[1] - 1
2954 # Collapse tiny gap into previous fragment
2955 if gap_size < min_gap:
2956 curr_range = (curr_range[0], stop)
2957 # Otherwise, start new range
2958 else:
2959 collapse.append(curr_range)
2960 curr_range = (start, stop)
2961 else:
2962 curr_range = (start, stop)
2963 # Last one
2964 if curr_range not in collapse and i == len(ranges) - 1:
2965 collapse.append(curr_range)
2966 # Collapse based on min length size
2967 collapse = [(start, stop) for start,stop in collapse if (stop - start + 1) >= min_len]
2968 return collapse
2971def structural(
2972 clusters: str,
2973 alignments: str,
2974 outdir: str = ".",
2975 prefix: str = None,
2976 min_len: int = 10,
2977 min_indel_len: int = 3,
2978 args: str = None,
2979 ):
2980 """
2981 Extract structural variants from cluster alignments.
2983 Takes as input the summarized clusters TSV and their individual alignments.
2984 Outputs an Rtab file of structural variants.
2986 >>> structural(clusters="summarize.clusters.tsv", alignments="alignments")
2987 >>> structural(clusters="summarize.tsv", alignments="alignments", min_len=100, min_indel_len=10)
2989 :param clusters: TSV file of clusters from summarize.
2990 :param alignments: Directory of alignments from align.
2991 :param outdir: Output directory.
2992 :param prefix: Prefix for output files.
2993 :param min_len: Minimum length of structural variants to extract.
2994 :param min_indel_len: Minimum length of gaps that should separate variants.
2995 :param args: Str of additional arguments [not implemented]
2997 :return: Ordered Dictionary of structural variants.
2998 """
3000 from Bio import SeqIO
3002 clusters_path, alignments_dir, output_dir = clusters, alignments, outdir
3003 args = args if args != None else ""
3004 prefix = f"{prefix}." if prefix != None else ""
3006 # Check output directory
3007 check_output_dir(output_dir)
3009 variants = OrderedDict()
3011 # -------------------------------------------------------------------------
3012 # Read Clusters
3014 logging.info(f"Reading summarized clusters: {clusters_path}")
3015 all_samples = []
3017 with open(clusters_path) as infile:
3018 header = [line.strip() for line in infile.readline().split("\t")]
3019 # sample columns begin after dbxref_alt
3020 all_samples = header[header.index("dbxref_alt")+1:]
3021 lines = infile.readlines()
3022 for line in tqdm(lines):
3023 row = [l.strip() for l in line.split("\t")]
3024 data = {k:v for k,v in zip(header, row)}
3025 cluster = data["cluster"]
3027 # Get mapping of sequence IDs to samples
3028 seq_to_sample = {}
3029 for sample in all_samples:
3030 sequence_ids = data[sample].split(",") if data[sample] != "" else []
3031 for sequence_id in sequence_ids:
3032 seq_to_sample[sequence_id] = sample
3034 # Read cluster alignment from the file
3035 alignment_path = os.path.join(alignments_dir, f"{cluster}.aln")
3036 if not os.path.exists(alignment_path):
3037 # Alignment might not exist if user excluded singletons in align
3038 num_samples = len(list(set(seq_to_sample.values())))
3039 if num_samples == 1:
3040 logging.debug(f"Singleton {cluster} alignment was not found: {alignment_path}")
3041 continue
3042 # otherwise, this alignment should exist, stop here
3043 else:
3044 msg = f"{cluster} alignment was not found: {alignment_path}"
3045 logging.error(msg)
3046 raise Exception(msg)
3048 alignment = OrderedDict()
3050 for record in SeqIO.parse(alignment_path, "fasta"):
3051 sample = seq_to_sample[record.id]
3052 if sample not in alignment:
3053 alignment[sample] = OrderedDict()
3054 alignment[sample][record.id] = record.seq
3056 # Parse out the fragment pieces based on the pattern of '-'
3057 cluster_variants = OrderedDict()
3059 present_samples = set()
3060 for sample in alignment:
3061 sample_frag_ranges = []
3062 for seq in alignment[sample].values():
3063 non_ambig_nuc = set([nuc for nuc in seq if nuc in NUCLEOTIDES])
3064 if len(non_ambig_nuc) > 0:
3065 present_samples.add(sample)
3066 frag_ranges = [(start+1, stop+1) for start,stop in list(get_ranges([i for i,nuc in enumerate(seq) if nuc != "-"]))]
3067 frag_ranges_collapse = collapse_ranges(ranges=frag_ranges, min_len=min_len, min_gap=min_indel_len)
3068 sample_frag_ranges.append(frag_ranges_collapse)
3070 # Sort them for output consistency
3071 sample_frag_ranges = sorted(sample_frag_ranges)
3073 for frag in sample_frag_ranges:
3074 # Simple presence
3075 frag_text = f"{cluster}|structural:" + "_".join([f"{start}-{stop}" for start,stop in frag])
3076 if frag_text not in cluster_variants:
3077 cluster_variants[frag_text] = []
3078 if sample not in cluster_variants[frag_text]:
3079 cluster_variants[frag_text].append(sample)
3080 # Copy number (ex. 2X)
3081 count = sample_frag_ranges.count(frag)
3082 if count > 1:
3083 copy_number_text = f"{frag_text}|{count}X"
3084 if copy_number_text not in cluster_variants:
3085 cluster_variants[copy_number_text] = []
3086 if sample not in cluster_variants[copy_number_text]:
3087 cluster_variants[copy_number_text].append(sample)
3089 # If there is only one structure, it's not variant
3090 if len(cluster_variants) <= 1: continue
3091 # Samples that are truly missing will be given a "."
3092 missing_samples = [s for s in all_samples if s not in present_samples]
3093 for variant,samples in cluster_variants.items():
3094 if len(samples) == len(all_samples): continue
3095 variants[variant] = {"present": samples, "missing": missing_samples}
3097 # -------------------------------------------------------------------------
3098 # Write Structural Variants Rtab
3100 rtab_path = os.path.join(output_dir, f"{prefix}structural.Rtab")
3101 logging.info(f"Writing variants: {rtab_path}")
3103 with open(rtab_path, 'w') as outfile:
3104 header = ["Variant"] + all_samples
3105 outfile.write("\t".join(header) + "\n")
3106 for variant,data in tqdm(variants.items()):
3107 observations = ["1" if s in data["present"] else "." if s in data["missing"] else "0" for s in all_samples]
3108 line = "\t".join([variant] + observations)
3109 outfile.write(line + "\n")
3111 return rtab_path
3114def snps(
3115 alignment: str,
3116 bed: str,
3117 consensus: str,
3118 outdir: str = ".",
3119 prefix: str = None,
3120 structural: str = None,
3121 core: float = 0.95,
3122 indel_window: int = 0,
3123 snp_window: int = 0,
3124 args: str = None,
3125 ):
3126 """
3127 Extract SNPs from a pangenome alignment.
3129 Takes as input the pangenome alignment fasta, bed, and consensus file.
3130 Outputs an Rtab file of SNPs.
3132 >>> snps(alignment="pangenome.aln", bed="pangenome.bed", consensus="pangenome.consensus.fasta")
3133 >>> snps(alignment="pangenome.aln", bed="pangenome.bed", consensus="pangenome.consensus.fasta", structural="structural.Rtab", indel_window=10, snp_window=3)
3135 :param alignment: FASTA file of the pangenome alignment from align.
3136 :param bed: BED file of the pangenome coordinates from align.
3137 :param consensus: FASTA file of the pangenome consensus from align.
3138 :param outdir: Output directory.
3139 :param prefix: Prefix for output files.
3140 :param structural: Rtab file of structural variants from structural.
3141 :param core: Core genome threshold for calling core SNPs.
3142 :param indel_window: Exclude SNPs that are within this proximity to indels.
3143 :param snp_window: Exclude SNPs that are within this proximity to another SNP.
3144 :param args: Str of additional arguments [not implemented]
3145 """
3147 alignment_path, bed_path, consensus_path, structural_path, output_dir = alignment, bed, consensus, structural, outdir
3148 args = args if args != None else ""
3149 prefix = f"{prefix}." if prefix != None else ""
3151 # Check output directory
3152 check_output_dir(output_dir)
3154 # -------------------------------------------------------------------------
3155 # Read Pangenome Bed
3157 clusters = {}
3158 logging.info(f"Reading bed: {bed_path}")
3159 with open(bed_path) as infile:
3160 lines = infile.readlines()
3161 for line in tqdm(lines):
3162 line = [l.strip() for l in line.split("\t")]
3163 start, end, name = (int(line[1]), int(line[2]), line[3])
3164 info = {n.split("=")[0]:n.split("=")[1] for n in name.split(";")}
3165 cluster = info["cluster"]
3166 synteny = info["synteny"]
3167 clusters[cluster] = {"start": start, "end": end, "synteny": synteny }
3169 # -------------------------------------------------------------------------
3170 # Read Pangenome Alignment
3172 alignment = {}
3173 all_samples = []
3174 logging.info(f"Reading alignment: {alignment_path}")
3175 with open(alignment_path) as infile:
3176 records = infile.read().split(">")[1:]
3177 for record in tqdm(records):
3178 record_split = record.split("\n")
3179 sample = record_split[0]
3180 if sample not in all_samples:
3181 all_samples.append(sample)
3182 sequence = "".join(record_split[1:]).replace("\n", "").upper()
3183 alignment[sample] = sequence
3185 # -------------------------------------------------------------------------
3186 # Read Consensus Sequence
3188 representative = ""
3189 logging.info(f"Reading consensus sequence: {consensus_path}")
3190 with open(consensus_path) as infile:
3191 record = infile.read().split(">")[1:][0]
3192 record_split = record.split("\n")
3193 representative = "".join(record_split[1:]).replace("\n", "").upper()
3195 alignment_len = len(alignment[all_samples[0]])
3197 # -------------------------------------------------------------------------
3198 # Read Optional Structural Rtab
3200 # Use the structural variants to locate the terminal ends of sequences
3202 structural = OrderedDict()
3203 if structural_path:
3204 logging.info(f"Reading structural: {structural_path}")
3205 with open(structural_path) as infile:
3206 header = infile.readline().strip().split("\t")
3207 samples = header[1:]
3208 lines = infile.readlines()
3209 for line in tqdm(lines):
3210 row = [v.strip() for v in line.split("\t")]
3211 variant = row[0].split("|")
3212 cluster = variant[0]
3213 if cluster not in structural:
3214 structural[cluster] = OrderedDict()
3215 # Cluster_1|structural|1-195_202-240 --> Terminal=(1,240)
3216 coords = variant[1].split(":")[1].split("_")
3217 start = int(coords[0].split("-")[0])
3218 end = int(coords[len(coords) - 1].split("-")[1])
3219 for sample,observation in zip(samples, row[1:]):
3220 if observation == "1":
3221 if sample not in structural[cluster]:
3222 structural[cluster][sample] = []
3223 structural[cluster][sample].append((start, end))
3225 # -------------------------------------------------------------------------
3226 # Extract SNPs from Alignment
3227 # -------------------------------------------------------------------------
3229 logging.info("Extracting SNPs.")
3231 constant_sites = {n:0 for n in NUCLEOTIDES}
3233 snps_data = OrderedDict()
3235 # Iterate over alignment according to cluster
3236 for cluster,cluster_data in tqdm(clusters.items()):
3237 synteny = cluster_data["synteny"]
3238 cluster_start, cluster_stop = cluster_data["start"], cluster_data["end"]
3239 for i in range(cluster_start - 1, cluster_stop):
3240 # Extract nucleotides for all samples
3241 nuc = {s:alignment[s][i] for s in all_samples}
3242 nuc_non_ambig = [n for n in nuc.values() if n in NUCLEOTIDES]
3243 nuc_non_ambig_set = set(nuc_non_ambig)
3244 prop_non_ambig = len(nuc_non_ambig) / len(nuc)
3246 # Option #1. If all missing/ambiguous, skip over
3247 if len(nuc_non_ambig_set) == 0:
3248 continue
3249 # Option #2. All the same/invariant
3250 elif len(nuc_non_ambig_set) == 1:
3251 # record for constant sites if it's a core site
3252 if prop_non_ambig >= core:
3253 n = list(nuc_non_ambig)[0]
3254 constant_sites[n] += 1
3255 continue
3256 # Option #3. Variant, process it further
3257 # Make snp positions 1-based (like VCF)
3258 pangenome_pos = i + 1
3259 cluster_pos = pangenome_pos - cluster_start
3260 # Treat the representative nucleotide as the reference
3261 ref = representative[i]
3262 alt = []
3263 genotypes = []
3264 sample_genotypes = {}
3266 # Filter on multi-allelic and indel proximity
3267 for s,n in nuc.items():
3269 # -------------------------------------------------------------
3270 # Indel proximity checking
3272 if indel_window > 0:
3273 # Check window around SNP for indels
3274 upstream_i = i - indel_window
3275 if upstream_i < cluster_start:
3276 upstream_i = cluster_start
3277 downstream_i = i + indel_window
3278 if downstream_i > (cluster_stop - 1):
3279 downstream_i = (cluster_stop - 1)
3281 # Check if we should truncate the downstream/upstream i by terminals
3282 if cluster in structural and s in structural[cluster]:
3283 for start,stop in structural[cluster][s]:
3284 # Adjust start,stop coordinates from cluster to whole genome
3285 start_i, stop_i = (cluster_start + start) - 1, (cluster_start + stop) - 1
3286 # Original: 226-232, Terminal: 0,230 --> 226-230
3287 if stop_i > upstream_i and stop_i < downstream_i:
3288 downstream_i = stop_i
3289 # Original: 302-308, Terminal: 303,389 --> 303-308
3290 if start_i > upstream_i and start_i < downstream_i:
3291 upstream_i = start_i
3293 context = alignment[s][upstream_i:downstream_i + 1]
3295 # If indel is found nearby, mark this as missing/ambiguous
3296 if "-" in context:
3297 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) in {s} was filtered out due to indel proximity: {context}")
3298 genotypes.append(".")
3299 nuc[s] = "."
3300 sample_genotypes[s] = "./."
3301 continue
3303 # -------------------------------------------------------------
3304 # sample nuc is different than ref
3305 if n != ref:
3306 # handle ambiguous/missing
3307 if n not in NUCLEOTIDES:
3308 genotypes.append(".")
3309 nuc[s] = "."
3310 sample_genotypes[s] = "./."
3311 continue
3313 # add this as a new alt genotype
3314 if n not in alt:
3315 alt.append(n)
3316 sample_genotypes[s] = "1/1"
3317 else:
3318 sample_genotypes[s] = "0/0"
3320 genotypes.append(([ref] + alt).index(n))
3322 # Update our non-ambiguous nucleotides
3323 nuc_non_ambig = [n for n in nuc.values() if n in NUCLEOTIDES]
3324 nuc_non_ambig_set = set(nuc_non_ambig)
3326 genotypes_non_ambig = [g for g in genotypes if g != "."]
3328 # Check if it's still a variant position after indel filtering
3329 if len(alt) == 0 or len(genotypes_non_ambig) == 1:
3330 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as mono-allelic: {ref}")
3331 constant_sites[ref] += 1
3332 continue
3333 # If more than 1 ALT alleles, this is non-biallelic (multi-allelic) so we'll filter it out
3334 elif len(alt) > 1:
3335 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as multi-allelic: {','.join([ref] + alt)}")
3336 continue
3338 alt = list(alt)[0]
3339 # Use "." to indicate missing values (ex. Rtab format)
3340 observations = ["1" if n == alt else "." if n != ref else "0" for n in nuc.values()]
3341 # One more check if it's no longer variant...
3342 observations_non_ambig = set([o for o in observations if o == "1" or o == "0"])
3343 if len(observations_non_ambig) <= 1:
3344 observations_nuc = alt[0] if "1" in observations_non_ambig else ref if "0" in observations_non_ambig else "N"
3345 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as mono-allelic: {observations_nuc}")
3346 continue
3348 # This is a valid SNP! Calculate extra stats
3349 allele_frequency = len([n for n in nuc_non_ambig if n != ref]) / len(nuc)
3351 # Update data
3352 pangenome_snp = f"{ref}{pangenome_pos}{alt}"
3353 cluster_snp = f"{cluster}|snp:{ref}{cluster_pos}{alt}"
3354 snps_data[cluster_snp] = OrderedDict({
3355 "pangenome_snp": pangenome_snp,
3356 "prop_non_ambig" : prop_non_ambig,
3357 "allele_frequency": allele_frequency,
3358 "cluster" : cluster,
3359 "synteny": synteny,
3360 "observations": observations,
3361 "pangenome_pos": pangenome_pos,
3362 "cluster_pos": cluster_pos,
3363 "nuc": nuc,
3364 "ref": ref,
3365 "alt": alt,
3366 "sample_genotypes": sample_genotypes,
3367 })
3369 if snp_window > 0:
3370 snps_exclude = set()
3371 logging.info(f"Filtering out SNPs within {snp_window} bp of each other.")
3372 snps_order = list(snps_data.keys())
3373 for i,snp in enumerate(snps_order):
3374 cluster, coord = snps_data[snp]["cluster"], snps_data[snp]["cluster_pos"]
3375 if i > 0:
3376 prev_snp = snps_order[i-1]
3377 prev_cluster, prev_coord = snps_data[prev_snp]["cluster"], snps_data[prev_snp]["cluster_pos"],
3378 if (cluster == prev_cluster) and (coord - prev_coord <= snp_window):
3379 snps_exclude.add(snp)
3380 snps_exclude.add(prev_snp)
3381 logging.debug(f"SNP {snp} and {prev_snp} are filtered out due to proximity <= {snp_window}.")
3382 snps_data = OrderedDict({snp:data for snp,data in snps_data.items() if snp not in snps_exclude})
3384 # -------------------------------------------------------------------------
3385 # Prepare Outputs
3387 snp_all_alignment = { s:[] for s in all_samples}
3388 snp_core_alignment = { s:[] for s in all_samples}
3389 snp_all_table = open(os.path.join(output_dir, f"{prefix}snps.all.tsv"), 'w')
3390 snp_core_table = open(os.path.join(output_dir, f"{prefix}snps.core.tsv"), 'w')
3391 snp_all_vcf = open(os.path.join(output_dir, f"{prefix}snps.all.vcf"), 'w')
3392 snp_core_vcf = open(os.path.join(output_dir, f"{prefix}snps.core.vcf"), 'w')
3393 snp_rtab_path = os.path.join(output_dir, f"{prefix}snps.Rtab")
3394 snp_rtab = open(snp_rtab_path, 'w')
3396 # TSV Table Header
3397 header = ["snp", "pangenome_snp", "prop_non_ambig", "allele_frequency", "cluster"] + all_samples
3398 snp_all_table.write("\t".join(header) + "\n")
3399 snp_core_table.write("\t".join(header) + "\n")
3401 # Rtab Header
3402 header = ["Variant"] + all_samples
3403 snp_rtab.write("\t".join(header) + "\n")
3405 # VCF Header
3406 all_samples_header = "\t".join(all_samples)
3407 header = textwrap.dedent(
3408 f"""\
3409 ##fileformat=VCFv4.2
3410 ##contig=<ID=pangenome,length={alignment_len}>
3411 ##INFO=<ID=CR,Number=0,Type=Flag,Description="Consensus reference allele, not based on real reference genome">
3412 ##INFO=<ID=C,Number=1,Type=String,Description="Cluster">
3413 ##INFO=<ID=CP,Number=1,Type=Int,Description="Cluster Position">
3414 ##INFO=<ID=S,Number=1,Type=String,Description="Synteny">
3415 ##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
3416 ##FILTER=<ID=IW,Description="Indel window proximity filter {indel_window}">
3417 ##FILTER=<ID=SW,Description="SNP window proximity filter {snp_window}">
3418 #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t{all_samples_header}
3419 """
3420 )
3421 snp_all_vcf.write(header)
3422 snp_core_vcf.write(header)
3424 logging.info(f"Finalizing all and core SNPs ({core}).")
3425 for snp,d in snps_data.items():
3426 # Table
3427 table_row = [snp] + [d["pangenome_snp"], d["prop_non_ambig"], d["allele_frequency"], d["cluster"]] + d["observations"]
3428 snp_all_table.write("\t".join(str(v) for v in table_row) + "\n")
3429 # Alignment
3430 for s in all_samples:
3431 snp_all_alignment[s].append(d["nuc"][s])
3432 # Rtab
3433 rtab_row = [snp] + d["observations"]
3434 snp_rtab.write("\t".join(rtab_row) + "\n")
3435 # VCF
3436 pangenome_pos, cluster_pos = d["pangenome_pos"], d["cluster_pos"]
3437 cluster, synteny = d["cluster"], d["synteny"]
3438 ref, alt = d["ref"], d["alt"]
3439 info = f"CR;C={cluster};CP={cluster_pos},S={synteny}"
3440 genotypes = []
3441 for sample in all_samples:
3442 genotypes.append(d["sample_genotypes"][sample])
3443 vcf_row = ["pangenome", pangenome_pos, snp, ref, alt, ".", "PASS", info, "GT"] + genotypes
3444 vcf_line = "\t".join([str(v) for v in vcf_row])
3445 snp_all_vcf.write(vcf_line + "\n")
3447 # core SNPs
3448 if d["prop_non_ambig"] >= core:
3449 # Table
3450 snp_core_table.write("\t".join(str(v) for v in table_row) + "\n")
3451 for s in all_samples:
3452 snp_core_alignment[s].append(d["nuc"][s])
3453 # VCF
3454 snp_core_vcf.write(vcf_line + "\n")
3456 # -------------------------------------------------------------------------
3457 # Write SNP fasta alignments
3459 snp_all_alignment_path = os.path.join(output_dir, f"{prefix}snps.all.fasta")
3460 logging.info(f"Writing all SNP alignment: {snp_all_alignment_path}")
3461 with open(snp_all_alignment_path, 'w') as outfile:
3462 for sample in tqdm(all_samples):
3463 sequence = "".join(snp_all_alignment[sample])
3464 outfile.write(f">{sample}\n{sequence}\n")
3466 snp_core_alignment_path = os.path.join(output_dir, f"{prefix}snps.core.fasta")
3467 logging.info(f"Writing {core} core SNP alignment: {snp_core_alignment_path}")
3468 with open(snp_core_alignment_path, 'w') as outfile:
3469 for sample in tqdm(all_samples):
3470 sequence = "".join(snp_core_alignment[sample])
3471 outfile.write(f">{sample}\n{sequence}\n")
3473 # -------------------------------------------------------------------------
3474 # Write constant sites
3476 constant_sites_path = os.path.join(output_dir, f"{prefix}snps.constant_sites.txt")
3477 logging.info(f"Writing constant sites: {constant_sites_path}")
3478 with open(constant_sites_path, 'w') as outfile:
3479 line = ",".join([str(v) for v in list(constant_sites.values())])
3480 outfile.write(line + "\n")
3482 # -------------------------------------------------------------------------
3483 # Cleanup
3485 snp_all_table.close()
3486 snp_core_table.close()
3487 snp_rtab.close()
3488 snp_all_vcf.close()
3489 snp_core_vcf.close()
3491 return snp_rtab_path
3494def presence_absence(
3495 clusters: str,
3496 outdir: str = ".",
3497 prefix: str = None,
3498 args: str = None
3499 ):
3500 """
3501 Extract presence absence of summarized clusters.
3503 Takes as input the TSV file of summarized clusters from summarize.
3504 Outputs an Rtab file of cluster presence/absence.
3506 Examples:
3507 >>> presence_absence(clusters="summarize.clusters.tsv")
3509 :param clusters: Path to TSV of summarized clusters from summarize.
3510 :param outdir: Output directory.
3511 :param prefix: Prefix for output files.
3512 :param args: Str of addtional arguments [not implemented].
3513 """
3515 clusters_path, output_dir = clusters, outdir
3516 args = args if args != None else ""
3517 prefix = f"{prefix}." if prefix != None else ""
3519 # Check output directory
3520 check_output_dir(output_dir)
3522 all_samples = []
3524 # -------------------------------------------------------------------------
3525 # Read Clusters
3527 logging.info(f"Reading summarized clusters: {clusters_path}")
3528 variants = OrderedDict()
3530 with open(clusters_path) as infile:
3531 header = [line.strip() for line in infile.readline().split("\t")]
3532 # sample columns begin after dbxref_alt
3533 all_samples = header[header.index("dbxref_alt")+1:]
3534 lines = infile.readlines()
3535 for line in tqdm(lines):
3536 row = [l.strip() for l in line.split("\t")]
3537 data = {k:v for k,v in zip(header, row)}
3538 cluster = data["cluster"]
3539 variants[cluster] = []
3541 # Get samples with sequences for this cluster
3542 for sample in all_samples:
3543 sequence_ids = data[sample].split(",") if data[sample] != "" else []
3544 if len(sequence_ids) > 0:
3545 variants[cluster].append(sample)
3546 # If the cluster is all samples, it's not "variable"
3547 if len(variants[cluster]) == len(all_samples):
3548 del variants[cluster]
3550 # -------------------------------------------------------------------------
3551 # Write Presence Absence Rtab
3553 rtab_path = os.path.join(output_dir, f"{prefix}presence_absence.Rtab")
3554 logging.info(f"Writing variants: {rtab_path}")
3556 with open(rtab_path, 'w') as outfile:
3557 header = ["Variant"] + all_samples
3558 outfile.write("\t".join(header) + "\n")
3559 for variant,samples in tqdm(variants.items()):
3560 variant = f"{variant}|presence_absence"
3561 observations = ["1" if s in samples else "0" for s in all_samples]
3562 line = "\t".join([variant] + observations)
3563 outfile.write(line + "\n")
3565 return rtab_path
3568def combine(
3569 rtab: list,
3570 outdir: str = ".",
3571 prefix: str = None,
3572 args: str = None,
3573 ):
3574 """
3575 Combine variants from multiple Rtab files.
3577 Takes as input a list of file paths to Rtab files. Outputs an Rtab file with
3578 the variants concatenated, ensuring consistent ordering of the sample columns.
3580 >>> combine(rtab=["snps.Rtab", "structural.Rtab", "presence_absence.Rtab"])
3582 :param rtab: List of paths to input Rtab files to combine.
3583 :param outdir: Output directory.
3584 :param prefix: Prefix for output files.
3585 :param args: Str of additional arguments [not implemented].
3586 """
3588 output_dir = outdir
3589 prefix = f"{prefix}." if prefix != None else ""
3590 # Check output directory
3591 check_output_dir(output_dir)
3593 all_samples = []
3594 rtab_path = os.path.join(output_dir, f"{prefix}combine.Rtab")
3595 logging.info(f"Writing combined Rtab: {rtab_path}")
3596 with open(rtab_path, 'w') as outfile:
3597 for file_path in rtab:
3598 if not file_path: continue
3599 logging.info(f"Reading variants: {file_path}")
3600 with open(file_path) as infile:
3601 header = infile.readline().strip().split("\t")
3602 samples = header[1:]
3603 if len(all_samples) == 0:
3604 all_samples = samples
3605 outfile.write("\t".join(["Variant"] + all_samples) + "\n")
3606 # Organize the sample order consistently
3607 lines = infile.readlines()
3608 for line in tqdm(lines):
3609 row = [v.strip() for v in line.split("\t")]
3610 variant = row[0]
3611 observations = {k:v for k,v in zip(samples, row[1:])}
3612 out_row = [variant] + [observations[s] for s in all_samples]
3613 outfile.write("\t".join(out_row) + "\n")
3615 return rtab_path
3618def root_tree(
3619 tree: str,
3620 outgroup: list = [],
3621 tree_format: str = "newick",
3622 outdir: str = ".",
3623 prefix: str = None,
3624 args: str = None,
3625 ):
3626 """
3627 Root a tree on outgroup taxa.
3628 """
3630 tree_path, output_dir = tree, outdir
3631 prefix = f"{prefix}." if prefix != None else ""
3632 # Check output directory
3633 check_output_dir(output_dir)
3635 if type(outgroup) == str:
3636 outgroup = outgroup.split(",")
3637 elif type(outgroup) != list:
3638 msg = f"The outgroup must be either a list or CSV string: {outgroup}"
3639 logging.error(msg)
3640 raise Exception(msg)
3642 # Prioritize tree path
3643 logging.info(f"Reading tree: {tree_path}")
3644 tree = read_tree(tree_path, tree_format=tree_format)
3645 tree.is_rooted = True
3646 # Cleanup node/taxon labels
3647 for node in list(tree.preorder_node_iter()):
3648 if node.taxon:
3649 node.label = str(node.taxon)
3650 node.label = node.label.replace("\"", "").replace("'", "") if node.label != None else None
3652 logging.info(f"Rooting tree with outgroup: {outgroup}")
3654 tree_tips = [n.label for n in tree.leaf_node_iter()]
3655 if sorted(tree_tips) == sorted(outgroup):
3656 msg = "Your outgroup contains all taxa in the tree."
3657 logging.error(msg)
3658 raise Exception(msg)
3660 # sort the outgroup labels for easier checks later
3661 outgroup = sorted(outgroup)
3663 # Case 1: No outgroup requested, use first tip
3664 if len(outgroup) == 0:
3665 outgroup_node = [n for n in tree.leaf_node_iter()][0]
3666 logging.info(f"No outgroup requested, rooting on first tip: {outgroup_node.label}")
3667 edge_length = outgroup_node.edge_length / 2
3668 # Case 2: 1 or more outgroups requested
3669 else:
3670 # Case 2a: find a clade that is only the outgroup
3671 logging.info(f"Searching for outgroup clade: {outgroup}")
3672 outgroup_node = None
3673 for node in tree.nodes():
3674 node_tips = sorted([n.label for n in node.leaf_iter()])
3675 if node_tips == outgroup:
3676 logging.info(f"Outgroup clade found.")
3677 outgroup_node = node
3678 break
3679 # Case 2b: find a clade that is only the ingroup
3680 if outgroup_node == None:
3681 logging.info(f"Searching for ingroup clade that does not include: {outgroup}")
3682 for node in tree.nodes():
3683 node_tips = sorted([n.label for n in node.leaf_iter() if n.label in outgroup])
3684 if len(node_tips) == 0:
3685 # If this is a taxon, make sure the outgroup is complete
3686 if node.taxon != None:
3687 remaining_taxa = sorted([t for t in tree_tips if t != node.label])
3688 if remaining_taxa != outgroup:
3689 continue
3690 logging.info(f"Ingroup clade found.")
3691 outgroup_node = node
3692 break
3693 # If we couldn't find an outgroup, it's probably not monophyletic
3694 # If we created the tree with the same outgroup in IQTREE, this
3695 # shouldn't be possible, but we'll handle just in case
3696 if outgroup_node == None:
3697 msg = "Failed to find the outgroup clade. Are you sure your outgroup is monophyletic?"
3698 logging.error(msg)
3699 raise Exception(msg)
3700 edge_length = outgroup_node.edge_length / 2
3702 tree.reroot_at_edge(outgroup_node.edge, update_bipartitions=True, length1=edge_length, length2=edge_length )
3704 rooted_path = os.path.join(output_dir, f"{prefix}rooted.{tree_format}")
3705 logging.info(f"Writing rooted tree: {rooted_path}")
3706 # Suppress the [&R] prefix, which causes problems in downstream programs
3707 tree.write(path=rooted_path, schema=tree_format, suppress_rooting=True)
3709 return rooted_path
3711def root_tree_midpoint(tree):
3712 """
3713 Reroots the tree at the midpoint of the longest distance between
3714 two taxa in a tree.
3716 This is a modification of dendropy's reroot_at_midpoint function:
3717 https://github.com/jeetsukumaran/DendroPy/blob/49cf2cc2/src/dendropy/datamodel/treemodel/_tree.py#L2616
3719 This is a utility function that is used by gwas for the kinship
3720 similarity matrix.
3722 :param tree: dendropy.datamodel.treemodel._tree.Tree
3723 :return: dendropy.datamodel.treemodel._tree.Tree
3724 """
3725 from dendropy.calculate.phylogeneticdistance import PhylogeneticDistanceMatrix
3726 pdm = PhylogeneticDistanceMatrix.from_tree(tree)
3728 tax1, tax2 = pdm.max_pairwise_distance_taxa()
3729 plen = float(pdm.patristic_distance(tax1, tax2)) / 2
3731 n1 = tree.find_node_with_taxon_label(tax1.label)
3732 n2 = tree.find_node_with_taxon_label(tax2.label)
3734 # Find taxa furthest from the root
3735 mrca = pdm.mrca(tax1, tax2)
3736 n1_dist, n2_dist = n1.distance_from_root(), n2.distance_from_root()
3738 # Start at furthest taxa, walk up to root, look for possible edge
3739 # representing midpoint
3740 node = n1 if n1_dist >= n2_dist else n2
3741 remaining_len = plen
3743 while node != mrca:
3744 edge_length = node.edge.length
3745 remaining_len -= edge_length
3746 # Found midpoint node
3747 if remaining_len == 0:
3748 node = node._parent_node
3749 logging.info(f"Rerooting on midpoint node: {node}")
3750 tree.reroot_at_node(node, update_bipartitions=True)
3751 break
3752 # Found midpoint edge
3753 elif remaining_len < 0:
3754 l2 = edge_length + remaining_len
3755 l1 = edge_length - l2
3756 logging.info(f"Rerooting on midpoint edge, l1={l1}, l2={l2}")
3757 tree.reroot_at_edge(node.edge, update_bipartitions=True, length1=l1, length2=l2 )
3758 break
3759 node = node._parent_node
3761 return tree
3764def read_tree(tree, tree_format:str="newick"):
3765 """Read a tree from file into an object, and clean up labels."""
3766 from dendropy import Tree
3768 tree = Tree.get(path=tree, schema=tree_format, tree_offset=0, preserve_underscores=True)
3769 tree.is_rooted = True
3771 # Give tip nodes the name of their taxon, cleanup tip labels
3772 for node in list(tree.preorder_node_iter()):
3773 if node.taxon:
3774 node.label = str(node.taxon)
3775 node.label = node.label.replace("\"", "").replace("'", "") if node.label != None else None
3777 return tree
3780def tree(
3781 alignment: str,
3782 outdir: str = ".",
3783 prefix: str = None,
3784 threads: int = 1,
3785 constant_sites: str = None,
3786 args: str = TREE_ARGS,
3787 ):
3788 """
3789 Estimate a maximum-likelihood tree with IQ-TREE.
3791 Takes as input a multiple sequence alignment in FASTA format. If a SNP
3792 alignment is provided, an optional text file of constant sites can be
3793 included for correction. Outputs a maximum-likelihood tree, as well as
3794 additional rooted trees if an outgroup is specified in the iqtree args.
3796 tree(alignment="snps.core.fasta", constant_sites="snps.constant_sites.txt")
3797 tree(alignment="pangenome.aln", threads=4, args='--ufboot 1000 -o sample1')
3799 :param alignment: Path to multiple sequence alignment for IQ-TREE.
3800 :param outdir: Output directory.
3801 :param prefix: Prefix for output files.
3802 :param threads: CPU threads for IQ-TREE.
3803 :param constant_sites: Path to text file of constant site corrections.
3804 :param args: Str of additional arguments for IQ-TREE.
3805 """
3807 alignment_path, constant_sites_path, output_dir = alignment, constant_sites, outdir
3808 args = args if args != None else ""
3809 prefix = f"{prefix}." if prefix != None else ""
3811 # Check output directory
3812 check_output_dir(output_dir)
3814 # Check for seed and threads conflict
3815 if int(threads) > 1 and "--seed" in args:
3816 logging.warning(
3817 f"""
3818 Using multiple threads is not guaranteed to give reproducible
3819 results, even when using the --seed argument:
3820 See this issue for more information: https://github.com/iqtree/iqtree2/discussions/233
3821 """)
3822 args += f" -T {threads}"
3824 # Check for outgroup rooting
3825 if "-o " in args:
3826 outgroup = extract_cli_param(args, "-o")
3827 logging.info(f"Tree will be rooted on outgroup: {outgroup}")
3828 outgroup_labels = outgroup.split(",")
3829 else:
3830 outgroup = None
3831 outgroup_labels = []
3833 # Check for constant sites input for SNP
3834 if constant_sites_path:
3835 with open(constant_sites_path) as infile:
3836 constant_sites = infile.read().strip()
3837 constant_sites_arg = f"-fconst {constant_sites}"
3838 else:
3839 constant_sites_arg = ""
3841 # -------------------------------------------------------------------------
3842 # Estimate maximum likelihood tree
3843 iqtree_prefix = os.path.join(output_dir, f"{prefix}tree")
3844 run_cmd(f"iqtree -s {alignment_path} --prefix {iqtree_prefix} {constant_sites_arg} {args}")
3845 tree_path = f"{iqtree_prefix}.treefile"
3847 # Check for outgroup warnings in log
3848 log_path = f"{iqtree_prefix}.log"
3849 with open(log_path) as infile:
3850 for line in infile:
3851 outgroup_msg = "Branch separating outgroup is not found"
3852 if outgroup_msg in line:
3853 msg = f"Branch separating outgroup is not found. Please check if your ({outgroup}) is monophyletic in: {tree_path}"
3854 logging.error(msg)
3855 raise Exception(msg)
3857 # -------------------------------------------------------------------------
3858 # Fix root position
3860 tree_path = root_tree(tree=tree_path, outgroup=outgroup_labels, outdir=outdir, prefix=f"{prefix}tree")
3861 tree = read_tree(tree_path, tree_format="newick")
3862 tips = sorted([n.label for n in tree.leaf_node_iter()])
3863 # Sort child nodes, this extra sort should help us
3864 # organize tip nodes by their branch length?
3865 for node in tree.preorder_node_iter():
3866 node._child_nodes.sort(
3867 key=lambda node: getattr(getattr(node, "taxon", None), "label", ""),
3868 reverse=True
3869 )
3871 tree_path = f"{iqtree_prefix}.rooted.nwk"
3872 logging.info(f"Writing rooted tree: {tree_path}")
3873 # Suppress the [&R] prefix, which causes problems in downstream programs
3874 tree.write(path=tree_path, schema="newick", suppress_rooting=True)
3876 # -------------------------------------------------------------------------
3877 # Tidy branch supports
3879 node_i = 0
3881 branch_support_path = tree_path.replace(".treefile", ".branch_support.tsv").replace(".nwk", ".branch_support.tsv")
3882 logging.info(f"Extracting branch supports to table: {branch_support_path}")
3883 with open(branch_support_path, 'w') as outfile:
3884 header = ["node", "original", "ufboot", "alrt", "significant"]
3885 outfile.write('\t'.join(header) + '\n')
3886 for n in tree.preorder_node_iter():
3887 # Sanity check quotations
3888 if n.label:
3889 n.label = n.label.replace("'", "")
3890 if not n.is_leaf():
3891 if n.label:
3892 original = n.label
3893 else:
3894 original = "NA"
3895 support = n.label
3896 significant = ""
3897 ufboot = "NA"
3898 alrt = "NA"
3899 if support:
3900 if "/" in support:
3901 ufboot = float(support.split("/")[0])
3902 alrt = float(support.split("/")[1])
3903 if ufboot > 95 and alrt > 80:
3904 significant = "*"
3905 else:
3906 ufboot = float(support.split("/")[0])
3907 alrt = "NA"
3908 if ufboot > 95:
3909 significant = "*"
3910 n.label = f"NODE_{node_i}"
3911 row = [n.label, str(original), str(ufboot), str(alrt), significant]
3912 outfile.write('\t'.join(row) + '\n')
3913 node_i += 1
3915 # Tree V1: Branch support is replaced by node labels
3916 node_labels_path = tree_path.replace(".treefile", ".labelled_nodes.nwk").replace(".nwk", ".labelled_nodes.nwk")
3917 logging.info(f"Writing tree with node labels: {node_labels_path}")
3918 tree.write(path=node_labels_path, schema="newick", suppress_rooting=True)
3920 # Tree V2: No branch support or node labels (plain)
3921 plain_path = node_labels_path.replace(".labelled_nodes.nwk", ".plain.nwk")
3922 logging.info(f"Writing plain tree: {plain_path}")
3923 for n in tree.preorder_node_iter():
3924 n.label = None
3925 tree.write(path=plain_path, schema="newick", suppress_rooting=True)
3926 return tree_path
3929def get_delim(table:str):
3930 if table.endswith(".tsv") or table.endswith(".Rtab"):
3931 return "\t"
3932 elif table.endswith(".csv"):
3933 return ","
3934 else:
3935 msg = f"Unknown file extension of table: {table}"
3936 logging.error(msg)
3937 raise Exception(msg)
3939def table_to_rtab(
3940 table: str,
3941 filter: str,
3942 outdir: str = ".",
3943 prefix: str = None,
3944 args: str = None,
3945 ):
3946 """
3947 Convert a TSV/CSV table to an Rtab file based on regex filters.
3949 Takes as input a TSV/CSV table to convert, and a TSV/CSV of regex filters.
3950 The filter table should have a header with the fields: column, regex, name. Where column
3951 is the 'column' to search, 'regex' is the regular expression pattern, and
3952 'name' is how the output variant should be named in the Rtab.
3954 An example `filter.tsv` might look like this:
3956 column regex name extra
3957 assembly .*sample2.* sample2 data that will be ignored
3958 lineage .*2.* lineage_2 more data
3960 Where the goal is to filter the assembly and lineage columns for particular values.
3962 >>> table_to_rtab(table="samplesheet.csv", filter="filter.tsv")
3963 """
3965 table_path, filter_path, output_dir = table, filter, outdir
3966 prefix = f"{prefix}." if prefix != None else ""
3968 check_output_dir(output_dir)
3970 logging.info(f"Checking delimiters.")
3971 table_delim = get_delim(table_path)
3972 filter_delim = get_delim(filter_path)
3974 logging.info(f"Reading filters: {filter_path}")
3975 filters = OrderedDict()
3976 with open(filter_path) as infile:
3977 header = [h.strip() for h in infile.readline().split(filter_delim)]
3978 col_i, regex_i, name_i = [header.index(c) for c in ["column", "regex", "name"]]
3979 for line in infile:
3980 row = [l.strip() for l in line.split(filter_delim)]
3981 column,regex,name = row[col_i], row[regex_i], row[name_i]
3983 logging.debug(f"name: {name}, column: {column}, regex: {regex}")
3984 if name in filters:
3985 msg = f"Filter name {name} is not unique."
3986 logging.error(msg)
3987 raise Exception(msg)
3988 filters[name] = {k:v for k,v in zip(header, row)}
3990 logging.info(f"Searching input table: {table_path}")
3991 rtab_data = OrderedDict({name:{} for name in filters})
3992 all_samples = []
3993 with open(table_path) as infile:
3994 header = [h.strip() for h in infile.readline().split(table_delim)]
3995 # Locate filter columns in the file
3996 filter_names = list(filters.keys())
3997 for name in filter_names:
3998 column = filters[name]["column"]
3999 if column in header:
4000 filters[name]["i"] = header.index(column)
4001 else:
4002 logging.warning(f"Filter column {column} is not present in the table.")
4003 del filters[name]
4004 if len(filters) == 0:
4005 msg = "No filters were found matching columns in the input table."
4006 logging.error(msg)
4007 raise Exception(msg)
4008 # Parse the table lines
4009 for line in infile:
4010 row = [l.strip() for l in line.split(table_delim)]
4011 sample = row[0]
4012 if sample not in all_samples:
4013 all_samples.append(sample)
4014 for name,data in filters.items():
4015 regex = data["regex"]
4016 text = row[data["i"]]
4017 val = 1 if re.match(regex, text) else 0
4018 rtab_data[name][sample] = val
4020 rtab_path = os.path.join(output_dir, f"{prefix}output.Rtab")
4021 extended_path = os.path.join(output_dir, f"{prefix}output.tsv")
4022 output_delim = "\t"
4023 logging.info(f"Writing Rtab output: {rtab_path}")
4024 logging.info(f"Writing extended output: {extended_path}")
4025 with open(rtab_path, 'w') as rtab_outfile:
4026 header = ["Variant"] + [str(s) for s in all_samples]
4027 rtab_outfile.write(output_delim.join(header) + "\n")
4028 with open(extended_path, 'w') as extended_outfile:
4029 first_filter = list(filters.keys())[0]
4030 header = ["sample"] + [col for col in filters[first_filter].keys() if col != "i"]
4031 extended_outfile.write(output_delim.join(header) + "\n")
4033 for name,data in rtab_data.items():
4034 row = [name] + [str(data[s]) for s in all_samples]
4035 line = output_delim.join([str(val) for val in row])
4036 rtab_outfile.write(line + "\n")
4038 positive_samples = [s for s,o in data.items() if str(o) == "1"]
4039 for sample in positive_samples:
4040 row = [sample] + [v for k,v in filters[name].items() if k != "i"]
4041 line = output_delim.join([str(val) for val in row])
4042 extended_outfile.write(line + "\n")
4044def vcf_to_rtab(
4045 vcf: str,
4046 bed: str = None,
4047 outdir: str = ".",
4048 prefix: str = None,
4049 args: str = None,
4050 ):
4051 """
4052 Convert a VCF file to an Rtab file.
4054 >>> vcf_to_rtab(vcf="snps.csv")
4055 """
4057 vcf_path, bed_path, output_dir = vcf, bed, outdir
4058 prefix = f"{prefix}." if prefix != None else ""
4059 check_output_dir(output_dir)
4061 bed = OrderedDict()
4062 if bed_path != None:
4063 logging.info(f"Reading BED: {bed_path}")
4064 with open(bed_path) as infile:
4065 lines = infile.readlines()
4066 for line in tqdm(lines):
4067 row = [l.strip() for l in line.split("\t")]
4068 _chrom, start, end, name = row[0], row[1], row[2], row[3]
4069 bed[name] = {"start": int(start) + 1, "end": int(end) + 1}
4071 logging.info(f"Reading VCF: {vcf_path}")
4073 all_samples = []
4075 rtab_path = os.path.join(output_dir, f"{prefix}output.Rtab")
4077 with open(vcf_path) as infile:
4078 with open(rtab_path, 'w') as outfile:
4079 logging.info(f"Counting lines in VCF.")
4080 lines = infile.readlines()
4082 logging.info(f"Parsing variants.")
4083 for line in tqdm(lines):
4084 if line.startswith("##"): continue
4085 # Write the Rtab header
4086 if line.startswith("#"):
4087 header = [l.strip().replace("#", "") for l in line.split("\t")]
4088 # Samples start after 'FORMAT'
4089 samples_i = header.index("FORMAT") + 1
4090 all_samples = header[samples_i:]
4091 line = "\t".join(["Variant"] + [str(s) for s in all_samples])
4092 outfile.write(line + "\n")
4093 continue
4095 row = [l.strip() for l in line.split("\t")]
4096 data = {k:v for k,v in zip(header, row)}
4097 chrom, pos = f"{data['CHROM']}", int(data['POS'])
4098 ref = data['REF']
4099 alt = [nuc for nuc in data['ALT'].split(",") if nuc != "."]
4101 # Skip multiallelic/missing
4102 if len(alt) != 1: continue
4104 alt = alt[0]
4105 observations = []
4106 for sample in all_samples:
4107 genotype = data[sample]
4108 if genotype == "0/0":
4109 value = "0"
4110 elif genotype == "1/1":
4111 value = "1"
4112 else:
4113 value = "."
4115 observations.append(value)
4117 for name,region in bed.items():
4118 start, end = region["start"], region["end"]
4119 if pos >= start and pos < end:
4120 chrom = name
4121 pos = (pos - start) + 1
4122 break
4124 variant = f"{chrom}|snp:{ref}{pos}{alt}"
4126 line = "\t".join([variant] + observations)
4127 outfile.write(line + "\n")
4129def binarize(
4130 table: str,
4131 column: str,
4132 outdir: str = ".",
4133 prefix: str = None,
4134 output_delim: str = "\t",
4135 column_prefix: str = None,
4136 transpose: bool = False,
4137 ):
4138 """
4139 Convert a categorical column to multiple binary (0/1) columns.
4141 Takes as input a table (tsv or csv) as well as a column to binarize.
4143 >>> binarize(table="samplesheet.csv", column="lineage", output_delim=",")
4144 >>> binarize(table="samplesheet.csv", column="resistant", output_delim="\t", transpose=True)
4146 :param table: Path to table (TSV or CSV)
4147 :param column: Name of column in table.
4148 :param outdir: Path to the output directory.
4149 """
4151 table_path, output_dir = table, outdir
4152 column_prefix = column_prefix if column_prefix else ""
4153 prefix = f"{prefix}." if prefix != None else ""
4154 output_ext = "tsv" if output_delim == "\t" else "csv" if output_delim == "," else "txt"
4156 # Check output directory
4157 check_output_dir(output_dir)
4158 input_delim = get_delim(table_path)
4160 logging.info(f"Reading table: {table_path}")
4162 all_samples = []
4163 values = {}
4165 with open(table_path) as infile:
4166 header = infile.readline().strip().split(input_delim)
4167 sample_i = 0
4168 column_i = header.index(column)
4169 for line in infile:
4170 line = [v.strip() for v in line.split(input_delim)]
4171 sample, val = line[sample_i], line[column_i]
4172 if sample not in all_samples:
4173 all_samples.append(sample)
4174 if val == "": continue
4175 if val not in values:
4176 values[val] = []
4177 values[val].append(sample)
4179 # Sort values
4180 values = OrderedDict({v:values[v] for v in sorted(values.keys())})
4182 binarize_path = os.path.join(output_dir, f"{prefix}binarize.{output_ext}")
4183 logging.info(f"Writing binarized output: {binarize_path}")
4184 with open(binarize_path, 'w') as outfile:
4185 first_col = column if transpose else header[sample_i]
4186 if transpose:
4187 out_header = [first_col] + all_samples
4188 outfile.write(output_delim.join(out_header) + "\n")
4189 for val,samples in values.items():
4190 observations = ["1" if sample in samples else "0" for sample in all_samples]
4191 row = [f"{column_prefix}{val}"] + observations
4192 outfile.write(output_delim.join(row) + "\n")
4193 else:
4194 out_header = [first_col] + [f"{column_prefix}{k}" for k in values.keys()]
4195 outfile.write(output_delim.join(out_header) + "\n")
4196 for sample in all_samples:
4197 row = [sample]
4198 for val,val_samples in values.items():
4199 row.append("1") if sample in val_samples else row.append("0")
4200 outfile.write(output_delim.join(row) + "\n")
4203def qq_plot(
4204 locus_effects:str,
4205 outdir: str = ".",
4206 prefix: str = None,
4207 ):
4208 """
4209 Modified version of pyseer's qq_plot function.
4211 Source: https://github.com/mgalardini/pyseer/blob/master/scripts/qq_plot.py
4212 """
4214 import matplotlib.pyplot as plt
4215 import numpy as np
4216 import statsmodels
4217 import statsmodels.api as sm
4218 import pandas as pd
4220 locus_effects_path, output_dir = locus_effects, outdir
4221 prefix = f"{prefix}." if prefix != None else ""
4223 logging.info(f"Reading locus effects: {locus_effects_path}")
4224 lrt_pvalues = []
4225 with open(locus_effects_path) as infile:
4226 header = [l.strip() for l in infile.readline().split("\t")]
4227 for line in infile:
4228 row = [l.strip() for l in line.split("\t")]
4229 data = {k:v for k,v in zip(header, row)}
4230 lrt_pvalue = data["lrt-pvalue"]
4231 if lrt_pvalue != "":
4232 lrt_pvalues.append(float(lrt_pvalue))
4234 # Exclude p-value=0, which has no log
4235 y = -np.log10([val for val in lrt_pvalues if val != 0])
4236 x = -np.log10(np.random.uniform(0, 1, len(y)))
4238 # check for statsmodels version (issue #212)
4239 old_stats = False
4240 try:
4241 vmajor, vminor = statsmodels.__version__.split('.')[:2]
4242 if int(vmajor) == 0 and int(vminor) < 13:
4243 old_stats = True
4244 else:
4245 old_stats = False
4246 except:
4247 msg = "Failed to identify the statsmodel version, QQ plot will not be created."
4248 logging.warning(msg)
4249 return None
4251 if old_stats:
4252 xx = y
4253 yy = x
4254 else:
4255 xx = x
4256 yy = y
4258 # Plot
4259 logging.info(f"Creating QQ plot.")
4260 plt.figure(figsize=(4, 3.75))
4261 ax = plt.subplot(111)
4263 fig = sm.qqplot_2samples(xx,
4264 yy,
4265 xlabel='Expected $-log_{10}(pvalue)$',
4266 ylabel='Observed $-log_{10}(pvalue)$',
4267 line='45',
4268 ax=ax)
4270 ax = fig.axes[0]
4271 ax.lines[0].set_color('k')
4272 ax.lines[0].set_alpha(0.3)
4274 # Handle inf values if a p-value was 0
4275 x_max = x.max() if not np.isinf(x.max()) else max([val for val in x if not np.isinf(val)])
4276 y_max = y.max() if not np.isinf(y.max()) else max([val for val in y if not np.isinf(val)])
4277 ax.set_xlim(-0.5, x_max + 0.5)
4278 ax.set_ylim(-0.5, y_max + 0.5)
4280 plt.tight_layout()
4282 plot_path = os.path.join(output_dir, f"{prefix}qq_plot.png")
4283 logging.info(f"Writing plot: {plot_path}")
4284 plt.savefig(plot_path, dpi=150)
4285 plt.close('all')
4287 return plot_path
4290def gwas(
4291 variants: str,
4292 table: str,
4293 column: str,
4294 clusters: str = None,
4295 continuous: str = False,
4296 tree: str = None,
4297 midpoint: bool= True,
4298 outdir: str = ".",
4299 prefix: str = None,
4300 lineage_column: str = None,
4301 exclude_missing: bool = False,
4302 threads: int = 1,
4303 args: str = GWAS_ARGS,
4304 ):
4305 """
4306 Run genome-wide association study (GWAS) tests with pyseer.
4308 Takes as input the TSV file of summarized clusters, an Rtab file of variants,
4309 a TSV/CSV table of phenotypes, and a column name representing the trait of interest.
4310 Outputs tables of locus effects, and optionally lineage effects (bugwas) if specified.
4312 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='resistant')
4313 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='lineage', args='--no-distances')
4314 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='resistant', args='--lmm --min-af 0.05')
4316 :param variants: Path to Rtab file of variants.
4317 :param table: Path to TSV/CSV table of traits.
4318 :param column: Column name of variable in table.
4319 :param clusters: Path to TSV file of summarized clusters.
4320 :param continuous: Treat column as a continuous variable.
4321 :param tree: Path to newick tree.
4322 :param midpoint: True if the tree should be rerooted at the midpoint.
4323 :param outdir: Output directory.
4324 :param prefix: Prefix for output files.
4325 :param lineage_column: Name of lineage column in table (enables bugwas lineage effects).
4326 :param exclude_missing: Exclude samples missing phenotype data.
4327 :param threads: CPU threads for pyseer.
4328 :param args: Str of additional arguments for pyseer.
4329 """
4331 clusters_path, variants_path, table_path, tree_path, output_dir = clusters, variants, table, tree, outdir
4333 # Check output directory
4334 check_output_dir(output_dir)
4336 # Read in the table that contains the phenotypes (optional lineage column)
4337 logging.info(f"Reading table: {table_path}")
4338 with open(table_path) as infile:
4339 table = infile.read().replace(",", "\t")
4340 table_split = table.split("\n")
4341 table_header = table_split[0].split("\t")
4342 table_rows = [[l.strip() for l in line.split("\t") ] for line in table_split[1:] if line.strip() != ""]
4344 # Is the user didn't provide any custom args initialize as empty
4345 args = args if args != None else ""
4346 args += f" --cpu {threads}"
4347 prefix = f"{prefix}." if prefix != None else ""
4349 # Check for conflicts between input arguments
4350 if "--no-distances" not in args:
4351 if tree_path == None and "--lmm" in args:
4352 msg = "You must supply a phylogeny if you have not requested --no-distances."
4353 logging.error(msg)
4354 raise Exception(msg)
4355 else:
4356 if lineage_column:
4357 msg = "The param --no-distances cannot be used if a lineage column was specified."
4358 logging.error(msg)
4359 raise Exception(msg)
4361 if "--wg enet" in args:
4362 msg = "The whole genome elastic net model is not implemented yet."
4363 logging.error(msg)
4364 raise Exception(msg)
4366 # If a lineage column was specified, extract it from the table into
4367 # it's own file for pyseer.
4368 lineage_path = None
4369 sample_i = 0
4371 if lineage_column:
4372 lineage_path = os.path.join(output_dir, f"{prefix}{column}.lineages.tsv")
4373 logging.info(f"Extracting lineage column {lineage_column}: {lineage_path}")
4374 with open(lineage_path, 'w') as lineage_file:
4375 lineage_i = table_header.index(lineage_column)
4376 for row in table_rows:
4377 sample,lineage = row[sample_i], row[lineage_i]
4378 lineage_file.write(f"{sample}\t{lineage}\n")
4380 # -------------------------------------------------------------------------
4381 # Table / Phenotype
4383 if column not in table_header:
4384 msg = f"Column {column} is not present in the table: {table_path}"
4385 logging.error(msg)
4386 raise Exception(msg)
4388 column_i = table_header.index(column)
4390 # Exclude samples that are missing phenotype if requested
4391 exclude_samples = set()
4392 if exclude_missing:
4393 logging.info("Excluding samples missing phenotype data.")
4394 table_filter = []
4395 for row in table_rows:
4396 sample = row[sample_i]
4397 value = row[column_i]
4398 if value == "NA" or value == "":
4399 exclude_samples.add(sample)
4400 continue
4401 table_filter.append(row)
4402 table_rows = table_filter
4403 logging.info(f"Number of samples excluded: {len(exclude_samples)}")
4405 # Convert the phenotypes table from csv to tsv for pyseer
4406 if table_path.endswith(".csv"):
4407 logging.info(f"Converting csv to tsv for pyseer: {table_path}")
4408 file_name = os.path.basename(table_path)
4409 table_tsv_path = os.path.join(output_dir, f"{prefix}{column}." + file_name.replace(".csv", ".tsv"))
4410 with open(table_tsv_path, 'w') as outfile:
4411 outfile.write("\t".join(table_header) + "\n")
4412 for row in table_rows:
4413 outfile.write("\t".join(row) + "\n")
4415 # Check for categorical, binary, continuous trait
4416 logging.info(f"Checking type of column: {column}")
4418 observations = OrderedDict()
4419 for row in table_rows:
4420 sample, val = row[sample_i], row[column_i]
4421 observations[sample] = val
4423 all_samples = list(observations.keys())
4424 unique_observations = sorted(list(set(observations.values())))
4426 if len(unique_observations) <=1 :
4427 msg = f"Column {column} must have at least two different values."
4428 logging.error(msg)
4429 raise Exception(msg)
4431 # Now remove missing data
4432 unique_observations = [o for o in unique_observations if o != ""]
4433 if unique_observations == ["0", "1"] or unique_observations == ["0"] or unique_observations == ["1"]:
4434 logging.info(f"Column {column} is a binary variable with values 0/1.")
4435 unique_observations = [column]
4436 elif continuous == True:
4437 logging.info(f"Treating column {column} as continuous.")
4438 try:
4439 _check = [float(o) for o in unique_observations]
4440 except:
4441 msg = f"Failed to convert all values of {column} as numeric."
4442 logging.error(msg)
4443 raise Exception(msg)
4444 unique_observations = [column]
4445 args += " --continuous"
4446 else:
4447 logging.info(f"Binarizing categorical column {column} to multiple binary columns.")
4448 binarize_prefix = f"{prefix}{column}"
4449 binarize(table = table_tsv_path, column = column, column_prefix=f"{column}_", outdir=output_dir, prefix=binarize_prefix)
4450 table_tsv_path = os.path.join(output_dir, f"{binarize_prefix}.binarize.tsv")
4451 unique_observations = [f"{column}_{o}" for o in unique_observations]
4453 # -------------------------------------------------------------------------
4454 # (optional) Cluster annotations
4456 clusters = OrderedDict()
4457 clusters_header = []
4459 if clusters_path != None:
4460 logging.info(f"Reading summarized clusters: {clusters_path}")
4461 with open(clusters_path) as infile:
4462 clusters_header = [line.strip() for line in infile.readline().split("\t")]
4464 lines = infile.readlines()
4465 for line in tqdm(lines):
4466 row = [l.strip() for l in line.split("\t")]
4467 # Store the cluster data, exclude all sample columns, they will
4468 # be provided by the variant input
4469 data = {
4470 k:v for k,v in zip(clusters_header, row)
4471 if k not in exclude_samples and k not in all_samples
4472 }
4473 cluster = data["cluster"]
4474 clusters[cluster] = data
4476 # Exclude samples from clusters_header
4477 clusters_header = [
4478 col for col in clusters_header
4479 if col not in exclude_samples and col not in all_samples
4480 ]
4482 # -------------------------------------------------------------------------
4483 # Filter variants
4485 # Exclude missing samples and invariants
4486 # TBD: Might want to convert 'missing' chars all to '.'
4488 logging.info(f"Reading variants: {variants_path}")
4489 variants_filter_path = os.path.join(output_dir, f"{prefix}{column}.filter.Rtab")
4490 logging.info(f"Writing filtered variants: {variants_filter_path}")
4492 variants = {}
4493 variants_header = []
4494 exclude_variants = set()
4496 with open(variants_filter_path, "w") as outfile:
4497 with open(variants_path) as infile:
4498 header = infile.readline()
4499 # Save the header before filtering, to match up values to samples
4500 original_header = header.strip().split("\t")[1:]
4501 # Filter samples in the variant files
4502 variants_header = [s for s in original_header if s not in exclude_samples]
4503 outfile.write("\t".join(["Variant"] + variants_header) + "\n")
4505 # Filter invariants
4506 lines = infile.readlines()
4507 for line in tqdm(lines):
4508 row = [r.strip() for r in line.split("\t")]
4509 variant = row[0]
4510 observations = [
4511 o
4512 for s,o in zip(original_header, row[1:])
4513 if s not in exclude_samples
4514 ]
4515 unique_obs = set([o for o in observations])
4516 # Exclude simple invariants (less two possibilities) and invariants
4517 # with missing, ex. [".", "0"], [".", "1"]
4518 if (len(unique_obs) < 2) or (len(unique_obs) == 2 and ("." in unique_obs)):
4519 exclude_variants.add(variant)
4520 else:
4521 variants[variant] = observations
4522 filter_row = [variant] + observations
4523 outfile.write("\t".join(filter_row) + "\n")
4525 logging.info(f"Number of invariants excluded: {len(exclude_variants)}")
4526 logging.info(f"Number of variants included: {len(variants)}")
4527 variants_path = variants_filter_path
4529 # -------------------------------------------------------------------------
4530 # Distance Matrices
4532 patristic_path = os.path.join(output_dir, f"{prefix}{column}.patristic.tsv")
4533 kinship_path = os.path.join(output_dir, f"{prefix}{column}.kinship.tsv")
4535 if "--lmm" in args:
4536 logging.info("Running pyseer with similarity matrix due to: --lmm")
4537 args += f" --similarity {kinship_path}"
4538 if lineage_path:
4539 logging.info("Running pyseer with patristic distances due to: --lineage-column")
4540 args += f" --distances {patristic_path}"
4542 # The user might want to run with --no-distances if population structure (lineage)
4543 # is the variable of interest.
4544 if ("--distances" not in args and
4545 "--similarity" not in args):
4546 logging.info("Running pyseer with --no-distances")
4547 #args += " --no-distances"
4548 # Otherwise, we setup the distance matrices as needed
4549 else:
4550 logging.info(f"Creating distance matrices: {tree_path}")
4551 # A modified version of pyseer's phylogeny_distance function
4552 # Source: https://github.com/mgalardini/pyseer/blob/master/scripts/phylogeny_distance.py
4553 tree = read_tree(tree_path, tree_format="newick")
4555 if exclude_missing:
4556 logging.info(f"Excluding missing samples from the tree: {tree_path}")
4557 tree.prune_taxa_with_labels(labels=exclude_samples)
4558 tree.prune_leaves_without_taxa()
4560 tree_path = os.path.join(output_dir, f"{prefix}{column}.filter.nwk")
4561 tree.write(path=tree_path, schema="newick", suppress_rooting=True)
4563 if midpoint == True:
4564 tree_path = os.path.join(output_dir, f"{prefix}{column}.midpoint.nwk")
4565 logging.info(f"Rerooting tree at midpoint: {tree_path}")
4566 tree = root_tree_midpoint(tree)
4567 tree.write(path=tree_path, schema="newick", suppress_rooting=True)
4569 # Reread in tree after all that adjustments
4570 tree = read_tree(tree_path, tree_format="newick")
4572 patristic = OrderedDict()
4573 kinship = OrderedDict()
4574 distance_matrix = tree.phylogenetic_distance_matrix()
4576 for n1 in tree.taxon_namespace:
4577 if n1.label not in all_samples:
4578 logging.warning(f"Sample {n1.label} is present in the tree but not in the phenotypes, excluding from distance matrices.")
4579 continue
4580 kinship[n1.label] = kinship.get(n1.label, OrderedDict())
4581 patristic[n1.label] = patristic.get(n1.label, OrderedDict())
4582 for n2 in tree.taxon_namespace:
4583 if n2.label not in all_samples:
4584 continue
4585 # Measure kinship/similarity as distance to root of MRCA
4586 # TBD: Investigating impact of precision/significant digits
4587 # https://github.com/mgalardini/pyseer/issues/286
4588 if n2.label not in kinship[n1.label].keys():
4589 mrca = distance_matrix.mrca(n1, n2)
4590 distance = mrca.distance_from_root()
4591 kinship[n1.label][n2.label] = round(distance, 8)
4592 # Measure patristic distance between the nodes
4593 if n2.label not in patristic[n1.label].keys():
4594 distance = distance_matrix.patristic_distance(n1, n2)
4595 patristic[n1.label][n2.label] = round(distance, 8)
4597 if len(patristic) == 0 or len(kinship) == 0:
4598 msg = "No samples are present in the distance matrices! Please check that your table sample labels match the tree."
4599 logging.error(msg)
4600 raise Exception(msg)
4602 logging.info(f"Saving patristic distances to: {patristic_path}")
4603 logging.info(f"Saving similarity kinship to: {kinship_path}")
4605 with open(patristic_path, 'w') as patristic_file:
4606 with open(kinship_path, 'w') as kinship_file:
4607 distance_samples = [s for s in all_samples if s in kinship]
4608 distance_header = "\t" + "\t".join(distance_samples)
4609 patristic_file.write(distance_header + "\n")
4610 kinship_file.write(distance_header + "\n")
4611 # Row order based on all_samples for consistency
4612 for s1 in distance_samples:
4613 # Kinship/similarity
4614 row = [s1] + [str(kinship[s1][s2]) for s2 in distance_samples]
4615 kinship_file.write("\t".join(row) + "\n")
4616 # Patristic
4617 row = [s1] + [str(patristic[s1][s2]) for s2 in distance_samples]
4618 patristic_file.write("\t".join(row) + "\n")
4620 # -------------------------------------------------------------------------
4621 # Pyseer GWAS
4623 for new_column in unique_observations:
4624 logging.info(f"Running GWAS on column: {new_column}")
4626 log_path = os.path.join(output_dir, f"{prefix}{new_column}.pyseer.log")
4627 focal_path = os.path.join(output_dir, f"{prefix}{new_column}.focal.txt")
4628 patterns_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.patterns.txt")
4629 locus_effects_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.tsv")
4630 locus_significant_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.significant.tsv")
4631 locus_filter_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.significant.filter.tsv")
4632 lineage_effects_path = os.path.join(output_dir, f"{prefix}{new_column}.lineage_effects.tsv")
4633 lineage_significant_path = os.path.join(output_dir, f"{prefix}{new_column}.lineage_effects.significant.tsv")
4634 bonferroni_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.bonferroni.txt")
4635 cmd = textwrap.dedent(
4636 f"""\
4637 pyseer {args}
4638 --pres {variants_path}
4639 --phenotypes {table_tsv_path}
4640 --phenotype-column {new_column}
4641 --output-patterns {patterns_path}
4642 """)
4643 if lineage_path:
4644 cmd += f" --lineage --lineage-clusters {lineage_path} --lineage-file {lineage_effects_path}"
4645 cmd = cmd.replace("\n", " ")
4646 cmd_pretty = cmd + f" 1> {locus_effects_path} 2> {log_path}"
4647 logging.info(f"pyseer command: {cmd_pretty}")
4648 run_cmd(cmd, output=locus_effects_path, err=log_path, quiet=True, display_cmd=False)
4650 # Record which samples are 'positive' for this trait
4651 positive_samples = []
4652 negative_samples = []
4653 missing_samples = []
4654 with open(table_tsv_path) as infile:
4655 header = [l.strip() for l in infile.readline().split("\t")]
4656 for line in infile:
4657 row = [l.strip() for l in line.split("\t")]
4658 data = {k:v for k,v in zip(header, row)}
4659 sample, observation = data["sample"], data[new_column]
4660 if observation != ".":
4661 if float(observation) != 0:
4662 positive_samples.append(sample)
4663 else:
4664 negative_samples.append(sample)
4665 else:
4666 missing_samples.append(sample)
4668 logging.info(f"{len(positive_samples)}/{len(all_samples)} samples are positive (>0) for {new_column}.")
4669 logging.info(f"{len(negative_samples)}/{len(all_samples)} samples are negative (=0) for {new_column}.")
4670 logging.info(f"{len(missing_samples)}/{len(all_samples)} samples are missing (.) for {new_column}.")
4672 with open(focal_path, 'w') as outfile:
4673 outfile.write("\n".join(positive_samples))
4674 # -------------------------------------------------------------------------
4675 # Significant threshold
4677 logging.info(f"Determining significance threshold.")
4679 with open(patterns_path) as infile:
4680 patterns = list(set([l.strip() for l in infile.readlines()]))
4681 num_patterns = len(patterns)
4682 logging.info(f"Number of patterns: {num_patterns}")
4683 bonferroni = 0.05 / float(num_patterns) if num_patterns > 0 else 0
4684 logging.info(f"Bonferroni threshold (0.05 / num_patterns): {bonferroni}")
4685 with open(bonferroni_path, 'w') as outfile:
4686 outfile.write(str(bonferroni) + "\n")
4688 # -------------------------------------------------------------------------
4689 # Extract variants into dict
4691 logging.info(f"Extracting variants.")
4693 gwas_variants = OrderedDict()
4694 # Keep track of the smallest non-zero pvalue for the log10 transformation
4695 # of 0 values
4696 min_pvalue = 1
4697 all_pvalues = set()
4699 with open(locus_effects_path) as infile:
4700 locus_effects_header = infile.readline().strip().strip().split("\t")
4701 for line in infile:
4702 row = [l.strip() for l in line.split("\t")]
4703 data = {k:v for k,v in zip(locus_effects_header, row)}
4704 data["notes"] = ",".join(sorted(data["notes"].split(",")))
4705 # Convert pvalue
4706 pvalue = data["lrt-pvalue"]
4707 if pvalue != "":
4708 pvalue = float(pvalue)
4709 # Keep track of the smallest non-zero pvalue
4710 if pvalue != 0:
4711 min_pvalue = min(pvalue, min_pvalue)
4712 all_pvalues.add(pvalue)
4713 data["-log10(p)"] = ""
4714 data["bonferroni"] = bonferroni
4715 variant = data["variant"]
4716 gwas_variants[variant] = data
4717 locus_effects_header += ["-log10(p)", "bonferroni"]
4719 logging.info(f"Minimum pvalue observed: {min(all_pvalues)}")
4720 logging.info(f"Minimum pvalue observed (non-zero): {min_pvalue}")
4722 # -------------------------------------------------------------------------
4723 # -log10 transformation
4724 logging.info("Applying -log10 transformation.")
4726 # For p-values that are 0, we will use this tiny value instead
4727 # Remember, the min_pvalue is the minimum pvalue that is NOT 0
4728 small_val = float("1.00E-100")
4729 if small_val < min_pvalue:
4730 if 0 in all_pvalues:
4731 logging.warning(f"Using small float for pvalue=0 transformations: {small_val}")
4732 else:
4733 logging.warning(f"Using pvalue min for pvalue=0 transformations: {min_pvalue}")
4734 small_val = min_pvalue
4736 for variant, v_data in gwas_variants.items():
4737 pvalue = v_data["lrt-pvalue"]
4738 if pvalue == "": continue
4739 pvalue = float(pvalue)
4740 if pvalue == 0:
4741 logging.warning(f"{variant} has a pvalue of 0. log10 transformation will use: {small_val}")
4742 pvalue = small_val
4743 gwas_variants[variant]["-log10(p)"] = -math.log10(pvalue)
4745 # -------------------------------------------------------------------------
4746 # Annotate and sort the output files
4748 # Extract the cluster information for each variant
4749 logging.info(f"Extracting cluster identifiers.")
4750 locus_effects = OrderedDict()
4752 for variant,v_data in gwas_variants.items():
4753 variant_split = variant.split("|")
4754 cluster = variant_split[0]
4755 row = list(v_data.values())
4756 if cluster not in locus_effects:
4757 locus_effects[cluster] = []
4758 locus_effects[cluster].append(row)
4760 logging.info(f"Adding cluster annotations and variant observations.")
4761 locus_effects_rows = []
4762 match_header = ["match", "mismatch", "score"]
4763 for cluster in tqdm(locus_effects):
4764 cluster_row = [clusters[cluster][col] for col in clusters_header]
4765 for effect_row in locus_effects[cluster]:
4766 effect_data = {k:v for k,v in zip(locus_effects_header, effect_row)}
4767 try:
4768 beta = float(effect_data["beta"])
4769 except ValueError:
4770 beta = ""
4771 variant = effect_row[0]
4772 variant_row = variants[variant]
4773 variant_data = {k:v for k,v in zip(variants_header, variant_row)}
4774 # A simple summary statistic of how well the variant observations match the phenotype
4775 denom = len(all_samples)
4776 # Presence of variant should 'increase' or 'decrease' phenotype
4777 if beta == "" or beta > 0:
4778 v1, v2 = ["1", "0"]
4779 else:
4780 v1, v2 = ["0", "1"]
4781 match_num = len([
4782 k for k,v in variant_data.items()
4783 if (v == v1 and k in positive_samples) or (v == v2 and k in negative_samples)
4784 ])
4785 mismatch_num = len([
4786 k for k,v in variant_data.items()
4787 if (v == v1 and k not in positive_samples) or (v == v2 and k not in negative_samples)
4788 ])
4790 match_value = f"{match_num}/{denom}|{round(match_num / denom, 2)}"
4791 mismatch_value = f"{mismatch_num}/{denom}|{round(mismatch_num / denom, 2)}"
4792 match_score = (match_num/denom) - (mismatch_num/denom)
4793 # TBD What if we are on operating on lrt-filtering-fail p values?
4794 # We have no way to know the beta direction, and which way the score should go...
4795 if beta == "":
4796 match_score = abs(match_score)
4797 match_row = [match_value, mismatch_value, round(match_score, 2)]
4798 locus_effects_rows += [effect_row + match_row + cluster_row + variant_row]
4800 logging.info(f"Sorting locus effects by lrt-pvalue: {locus_effects_path}")
4801 locus_pvalues = OrderedDict()
4802 locus_null = OrderedDict()
4803 for i,row in enumerate(locus_effects_rows):
4804 data = {k:v for k,v in zip(locus_effects_header, row)}
4805 variant, lrt_pvalue = data["variant"], data["lrt-pvalue"]
4806 if lrt_pvalue == "":
4807 if lrt_pvalue not in locus_null:
4808 locus_null[lrt_pvalue] = []
4809 locus_null[lrt_pvalue].append((variant, i))
4810 else:
4811 lrt_pvalue = float(lrt_pvalue)
4812 if lrt_pvalue not in locus_pvalues:
4813 locus_pvalues[lrt_pvalue] = []
4814 locus_pvalues[lrt_pvalue].append((variant, i))
4815 locus_pvalues = sorted(locus_pvalues.items(), key=lambda kv: (kv[0], kv[1]))
4816 locus_null = sorted(locus_null.items(), key=lambda kv: (kv[0], kv[1]))
4817 locus_order = []
4818 for data in locus_pvalues + locus_null:
4819 for variant, i in data[1]:
4820 locus_order.append(i)
4821 locus_effects_rows = [locus_effects_rows[i] for i in locus_order]
4823 # Save the results
4824 with open(locus_effects_path, 'w') as outfile:
4825 header = "\t".join(locus_effects_header + match_header + clusters_header + variants_header)
4826 outfile.write(header + "\n")
4827 for row in locus_effects_rows:
4828 line = "\t".join([str(r) for r in row])
4829 outfile.write(line + "\n")
4831 logging.info(f"Sorting patterns: {patterns_path}")
4832 with open(patterns_path) as infile:
4833 lines = sorted([l for l in infile.readlines()])
4834 with open(patterns_path, 'w') as outfile:
4835 outfile.write("".join(lines))
4837 if lineage_path:
4838 logging.info(f"Sorting lineage effects: {lineage_effects_path}")
4839 with open(lineage_effects_path) as infile:
4840 header = infile.readline()
4841 lines = sorted([l for l in infile.readlines()])
4842 with open(lineage_effects_path, 'w') as outfile:
4843 outfile.write(header + "".join(lines))
4845 # -------------------------------------------------------------------------
4846 # Locus Effects
4848 logging.info("Identifying significant locus effects.")
4849 with open(locus_effects_path) as infile:
4850 header_line = infile.readline()
4851 header = header_line.strip().split("\t")
4852 with open(locus_significant_path, 'w') as significant_outfile:
4853 significant_outfile.write(header_line)
4854 with open(locus_filter_path, 'w') as filter_outfile:
4855 filter_outfile.write(header_line)
4856 for line in infile:
4857 row = [l.strip() for l in line.split("\t")]
4858 data = {k:v for k,v in zip(header, row)}
4859 lrt_pvalue, notes = data["lrt-pvalue"], data["notes"]
4860 if lrt_pvalue != "" and float(lrt_pvalue) < bonferroni:
4861 significant_outfile.write(line)
4862 if "bad-chisq" not in notes and "high-bse" not in notes:
4863 filter_outfile.write(line)
4865 # -------------------------------------------------------------------------
4866 # Lineage effects
4868 if lineage_path:
4869 logging.info("Identifying significant lineage effects.")
4870 with open(lineage_effects_path) as infile:
4871 header_line = infile.readline()
4872 header = header_line.strip().split("\t")
4873 pvalue_i = header.index("p-value")
4874 with open(lineage_significant_path, 'w') as significant_outfile:
4875 significant_outfile.write(header_line)
4876 for line in infile:
4877 row = [l.strip() for l in line.split("\t")]
4878 pvalue = float(row[pvalue_i])
4879 if pvalue < bonferroni:
4880 significant_outfile.write(line)
4882 # -------------------------------------------------------------------------
4883 # QQ Plot
4885 if len(locus_effects_rows) == 0:
4886 logging.info("Skipping QQ Plot as no variants were observed.")
4887 else:
4888 logging.info("Creating QQ Plot.")
4889 _plot_path = qq_plot(
4890 locus_effects = locus_effects_path,
4891 outdir = output_dir,
4892 prefix = f"{prefix}{new_column}"
4893 )
4896def text_to_path(text, tmp_path="tmp.svg", family="Roboto", size=16, clean=True):
4897 """
4898 Convert a string of text to SVG paths.
4900 :param text: String of text.
4901 :param tmp_path: File path to temporary file to render svg.
4902 :param family: Font family.
4903 :param size: Font size.
4904 :param clean: True if temporary file should be deleted upon completion.
4905 """
4907 import cairo
4908 from xml.dom import minidom
4909 from svgpathtools import parse_path
4912 # Approximate and appropriate height and width for the tmp canvas
4913 # This just needs to be at least big enough to hold it
4914 w = size * len(text)
4915 h = size * 2
4916 # Render text as individual glyphs
4917 with cairo.SVGSurface(tmp_path, w, h) as surface:
4918 context = cairo.Context(surface)
4919 context.move_to(0, h/2)
4920 context.set_font_size(size)
4921 context.select_font_face(family)
4922 context.show_text(text)
4924 # Parse text data and positions from svg DOM
4925 doc = minidom.parse(tmp_path)
4927 # Keep track of overall text bbox
4928 t_xmin, t_xmax, t_ymin, t_ymax = None, None, None, None
4929 glyphs = []
4931 # Absolute positions of each glyph are under <use>
4932 # ex. <use xlink:href="#glyph-0-1" x="11.96875" y="16"/>
4933 for u in doc.getElementsByTagName('use'):
4934 glyph_id = u.getAttribute("xlink:href").replace("#", "")
4935 # Get absolute position of the glyph
4936 x = float(u.getAttribute("x"))
4937 y = float(u.getAttribute("y"))
4938 # Get relative path data of the glyph
4939 for g in doc.getElementsByTagName('g'):
4940 g_id = g.getAttribute("id")
4941 if g_id != glyph_id: continue
4942 for path in g.getElementsByTagName('path'):
4943 d = path.getAttribute("d")
4944 p = parse_path(d)
4945 # Get glyph relative bbox
4946 xmin, xmax, ymin, ymax = [c for c in p.bbox()]
4947 # Convert bbox to absolute coordinates
4948 xmin, xmax, ymin, ymax = xmin + x, xmax + x, ymin + y, ymax + y
4949 # Update the text bbox
4950 if t_xmin == None:
4951 t_xmin, t_xmax, t_ymin, t_ymax = xmin, xmax, ymin, ymax
4952 else:
4953 t_xmin = min(t_xmin, xmin)
4954 t_xmax = max(t_xmax, xmax)
4955 t_ymin = min(t_ymin, ymin)
4956 t_ymax = max(t_ymax, ymax)
4957 bbox = xmin, xmax, ymin, ymax
4959 w, h = xmax - xmin, ymax - ymin
4960 glyph = {"id": glyph_id, "x": x, "y": y, "w": w, "h": h, "d": d, "bbox": bbox}
4961 glyphs.append(glyph)
4963 if not t_xmax:
4964 logging.error(f"Failed to render text: {text}")
4965 raise Exception(f"Failed to render text: {text}")
4966 # Calculate the final dimensions of the entire text
4967 w, h = t_xmax - t_xmin, t_ymax - t_ymin
4968 bbox = t_xmin, t_xmax, t_ymin, t_ymax
4970 result = {"text": text, "glyphs": glyphs, "w": w, "h": h, "bbox": bbox }
4972 if clean:
4973 os.remove(tmp_path)
4975 return result
4978def linear_scale(value, value_min, value_max, target_min=0.0, target_max=1.0):
4979 if value_min == value_max:
4980 return value
4981 else:
4982 return (value - value_min) * (target_max - target_min) / (value_max - value_min) + target_min
4984def manhattan(
4985 gwas: str,
4986 bed: str,
4987 outdir: str = ".",
4988 prefix: str = None,
4989 font_size: int = 16,
4990 font_family: str ="Roboto",
4991 margin: int = 20,
4992 width: int = 1000,
4993 height: int = 500,
4994 png_scale: float = 2.0,
4995 prop_x_axis: bool = False,
4996 ymax: float = None,
4997 max_blocks: int = 20,
4998 syntenies: bool = ["all"],
4999 clusters: bool = ["all"],
5000 variant_types: list = ["all"],
5001 args: str = None,
5002 ):
5003 """
5004 Plot the distribution of variant p-values across the genome.
5006 Takes as input a table of locus effects from the subcommand and a bed
5007 file such as the one producted by the align subcommand. Outputs a
5008 manhattan plot in SVG and PNG format.
5010 >>> manhattan(gwas="locus_effects.tsv", bed="pangenome.bed")
5011 >>> manhattan(gwas="locus_effects.tsv", bed="pangenome.bed", syntenies=["chromosome"], clusters=["pbpX"], variant_types=["snp", "presence_absence"])
5013 :param gwas: Path to tsv table of locus effects from gwas subcommand.
5014 :param bed: Path to BED file of coordinates.
5015 :param outdir: Output directory.
5016 :param prefix: Output prefix.
5017 :param width: Width in pixels of plot.
5018 :param height: Height in pixels of plot.
5019 :param margin: Size in pixels of plot margins.
5020 :param font_size: Font size.
5021 :param font_family: Font family.
5022 :param png_scale: Float that adjusts png scale relative to svg.
5023 :param prop_x_axis: Scale x-axis based on the length of the synteny block.
5024 :param ymax: Maximum value for the y-axis -log10(p).
5025 :param sytenies: Names of synteny blocks to plot.
5026 :param clusters: Names of clusters to plot.
5027 :params variant_types: Names of variant types to plot.
5028 """
5030 import numpy as np
5032 logging.info(f"Importing cairo.")
5033 import cairosvg
5035 gwas_path, bed_path, output_dir = gwas, bed, outdir
5036 prefix = f"{prefix}." if prefix != None else ""
5037 check_output_dir(outdir)
5039 # Type conversions fallback
5040 width, height, margin = int(width), int(height), int(margin)
5041 png_scale = float(png_scale)
5042 if ymax != None: ymax = float(ymax)
5044 # handle space values if given
5045 syntenies = syntenies if type(syntenies) != str else syntenies.split(" ")
5046 syntenies = [str(s) for s in syntenies]
5048 clusters = clusters if type(clusters) != str else clusters.split(" ")
5049 clusters = [str(c) for c in clusters]
5051 variant_types = variant_types if type(variant_types) != str else variant_types.split(" ")
5052 variant_types = [str(v) for v in variant_types]
5054 plot_data = OrderedDict()
5056 # -------------------------------------------------------------------------
5057 # Parse BED (Synteny Blocks)
5059 pangenome_length = 0
5061 logging.info(f"Reading BED coordinates: {bed_path}")
5062 with open(bed_path) as infile:
5063 for line in infile:
5064 row = [l.strip() for l in line.split("\t")]
5065 # Adjust 0-based bed coordinates back to 1 base
5066 c_start, c_end, name = int(row[1]) + 1, int(row[2]), row[3]
5067 pangenome_length = c_end
5068 c_length = (c_end - c_start) + 1
5069 # ex. cluster=geneA;synteny=443
5070 info = {n.split("=")[0]:n.split("=")[1] for n in name.split(";")}
5071 cluster, synteny = info["cluster"], info["synteny"]
5073 if synteny not in plot_data:
5074 plot_data[synteny] = {
5075 "pangenome_pos": [c_start, c_end],
5076 "length": c_length,
5077 "variants": OrderedDict(),
5078 "clusters": OrderedDict(),
5079 "ticks": OrderedDict(),
5080 "prop": 0.0,
5081 }
5083 # update the end position and length of the current synteny block
5084 plot_data[synteny]["pangenome_pos"][1] = c_end
5085 plot_data[synteny]["length"] = (
5086 c_end - plot_data[synteny]["pangenome_pos"][0]
5087 ) + 1
5089 s_start = plot_data[synteny]["pangenome_pos"][0]
5090 c_data = {
5091 "pangenome_pos": [c_start, c_end],
5092 "synteny_pos": [1 + c_start - s_start, 1+ c_end - s_start],
5093 "cluster_pos": [1, c_length],
5094 "length": c_length,
5095 }
5096 plot_data[synteny]["clusters"][cluster] = c_data
5098 # -------------------------------------------------------------------------
5099 # Parse GWAS (Variants)
5101 logging.info(f"Reading GWAS table: {gwas_path}")
5103 alpha = None
5104 log10_pvalues = set()
5106 with open(gwas_path) as infile:
5108 header = [l.strip()
5109 for l in infile.readline().split("\t")]
5110 if "cluster" not in header:
5111 msg = "GWAS table must contain cluster annotations for manhattan plot."
5112 logging.error(msg)
5113 raise Exception(msg)
5115 lines = infile.readlines()
5117 if len(lines) == 0:
5118 msg = "GWAS table contains no variants."
5119 logging.warning(msg)
5120 return 0
5122 for line in tqdm(lines):
5123 row = [l.strip() for l in line.split("\t")]
5124 data = {k:v for k,v in zip(header, row)}
5125 variant, cluster, synteny = data["variant"], data["cluster"], data["synteny"]
5127 alpha = float(data["bonferroni"]) if alpha == None else alpha
5129 if data["lrt-pvalue"] == "": continue
5130 pvalue = float(data["lrt-pvalue"])
5131 log10_pvalue = float(data["-log10(p)"])
5132 log10_pvalues.add(log10_pvalue)
5134 # try to find matching synteny block by name
5135 c_data = None
5136 if synteny in plot_data:
5137 c_data = plot_data[synteny]["clusters"][cluster]
5138 # Otherwise, find by cluster
5139 else:
5140 for s,s_data in plot_data.items():
5141 if cluster in s_data["clusters"]:
5142 synteny = s
5143 c_data = s_data["clusters"][cluster]
5145 if c_data == None:
5146 msg = f"Synteny block {synteny} for {cluster} is not present in the BED file."
5147 logging.error(msg)
5148 raise Exception(msg)
5150 s_data = plot_data[synteny]
5151 s_start, s_end = s_data["pangenome_pos"]
5152 c_start, c_end = c_data["synteny_pos"]
5154 # 3 different coordinates systems
5155 cluster_coords = []
5156 synteny_coords = []
5157 pangenome_coords = []
5159 if "|snp" in variant:
5160 variant_type = "snp"
5161 snp = variant.split("|")[1].split(":")[1]
5162 pos = int("".join([c for c in snp if c.isnumeric()]))
5163 cluster_pos = [pos, pos]
5164 cluster_coords.append(cluster_pos)
5165 elif "|presence_absence" in variant:
5166 variant_type = "presence_absence"
5167 cluster_pos = c_data["cluster_pos"]
5168 cluster_coords.append(cluster_pos)
5169 elif "|structural" in variant:
5170 variant_type = "structural"
5171 variant_split = variant.split("|")[1].split(":")[1]
5172 for interval in variant_split.split("_"):
5173 cluster_pos = [int(v) for v in interval.split("-")]
5174 cluster_coords.append(cluster_pos)
5175 else:
5176 logging.warning(f"Skipping unknown variant type: {variant}")
5177 continue
5179 # Apply filter
5180 if "all" not in variant_types and variant_type not in variant_types: continue
5181 if "all" not in syntenies and synteny not in syntenies: continue
5182 if "all" not in clusters and cluster not in clusters: continue
5184 # Convert cluster coords to synteny and pangenome coords
5185 for pos in cluster_coords:
5186 synteny_pos = [c_start + pos[0], c_start + pos[1]]
5187 synteny_coords.append(synteny_pos)
5188 pangenome_pos = [s_start + synteny_pos[0], s_start + synteny_pos[1]]
5189 pangenome_coords.append(pangenome_pos)
5191 plot_data[synteny]["variants"][variant] = {
5192 "variant": variant,
5193 "synteny": synteny,
5194 "synteny_pos": int(data["synteny_pos"]),
5195 "cluster": cluster,
5196 "pvalue": pvalue,
5197 "-log10(p)": log10_pvalue,
5198 "variant_h2": float(data["variant_h2"]),
5199 "type": variant_type,
5200 "cluster_coords": cluster_coords,
5201 "synteny_coords": synteny_coords,
5202 "pangenome_coords": pangenome_coords,
5203 }
5205 alpha_log10 = -math.log10(alpha)
5206 if ymax == None:
5207 max_log10 = math.ceil(max(log10_pvalues) if max(log10_pvalues) > alpha_log10 else alpha_log10)
5208 else:
5209 max_log10 = ymax
5211 min_log10 = 0
5212 log10_vals= np.arange(min_log10, max_log10 + 1, max_log10 / 4)
5214 # -------------------------------------------------------------------------
5215 # Exclude synteny blocks with no variants
5216 plot_data = OrderedDict({k:v for k,v in plot_data.items() if len(v["variants"]) > 0})
5218 # Adjust the total length based on just the synteny blocks we've observed
5219 total_length = sum(
5220 [sum([c["length"] for c in s["clusters"].values()])
5221 for s in plot_data.values()]
5222 )
5224 observed_clusters = set()
5225 for s in plot_data.values():
5226 for v in s["variants"].values():
5227 observed_clusters.add(v["cluster"])
5228 for s,s_data in plot_data.items():
5229 plot_data[s]["clusters"] = OrderedDict({
5230 k:v for k,v in plot_data[s]["clusters"].items()
5231 if k in observed_clusters
5232 })
5234 # If there's just a single cluster, adjust the total length
5235 if len(observed_clusters) == 1:
5236 only_synteny = list(plot_data.keys())[0]
5237 only_cluster = list(observed_clusters)[0]
5238 total_length = plot_data[only_synteny]["clusters"][only_cluster]["length"]
5239 elif "all" not in clusters:
5240 total_length = sum(
5241 [sum([c["length"] for c in s["clusters"].values()])
5242 for s in plot_data.values()]
5243 )
5245 if len(plot_data) == 0:
5246 msg = "No variants remain after filtering."
5247 logging.warning(msg)
5248 return None
5250 for s,s_data in plot_data.items():
5251 s_length = sum([c["length"] for c in s_data["clusters"].values()])
5252 s_prop = s_length / total_length
5253 plot_data[s]["prop"] = s_prop
5255 # -------------------------------------------------------------------------
5256 # Phandango data
5258 phandango_path = os.path.join(output_dir, f"{prefix}phandango.plot")
5259 logging.info(f"Creating phandango input: {phandango_path}")
5261 with open(phandango_path, 'w') as outfile:
5262 header = ["#CHR", "SNP", "BP", "minLOG10(P)", "log10(p)", "r^2"]
5263 outfile.write("\t".join(header) + "\n")
5265 for synteny,s_data in plot_data.items():
5266 pangenome_pos = s_data["pangenome_pos"]
5267 for variant, v_data in s_data["variants"].items():
5268 log10_pvalue = v_data["-log10(p)"]
5269 variant_h2 = v_data["variant_h2"]
5270 cluster = v_data["cluster"]
5271 c_data = s_data["clusters"][cluster]
5273 # If a snp, we take the actual position
5274 # for presence/absence or structural, we take the start
5275 if "|snp" in variant:
5276 pos = v_data["pangenome_coords"][0][0]
5277 else:
5278 pos = c_data["pangenome_pos"][0]
5279 row = ["pangenome", variant, pos, log10_pvalue, log10_pvalue, variant_h2]
5280 line = "\t".join([str(v) for v in row])
5281 outfile.write(line + "\n")
5283 # -------------------------------------------------------------------------
5284 # Y-Axis Label Dimensions
5286 logging.info(f"Calculating y-axis label dimensions.")
5288 # Main y-axis label
5289 y_axis_text = "-log10(p)"
5290 y_axis_label = text_to_path(text=y_axis_text, size=font_size, family=font_family)
5291 y_axis_label_h = y_axis_label["h"]
5292 y_axis_label_x = margin + y_axis_label_h
5293 y_axis_label["rx"] = y_axis_label_x
5294 # To set the y position, we'll need to wait until after x axis labels are done
5296 y_tick_hmax = 0
5297 y_tick_wmax = 0
5298 y_tick_labels = OrderedDict()
5300 for val in log10_vals:
5301 label = text_to_path(text=str(val), size=font_size * 0.75, family=font_family)
5302 w, h = label["w"], label["h"]
5303 y_tick_hmax, y_tick_wmax = max(y_tick_hmax, h), max(y_tick_wmax, w)
5304 y_tick_labels[val] = label
5306 # We now know some initial dimensions
5307 tick_len = y_tick_hmax / 2
5308 tick_pad = tick_len
5309 y_axis_x = y_axis_label_x + y_axis_label_h + y_tick_wmax + tick_pad + tick_len
5310 y_axis_y1 = margin
5311 # To set the y2 position, we'll need to wait until after x axis labels are done
5313 # -------------------------------------------------------------------------
5314 # X-Axis Label Dimensions
5316 logging.info(f"Calculating x-axis label dimensions.")
5318 # X-axis main label
5319 if len(plot_data) > max_blocks:
5320 x_axis_text = "Pangenome"
5321 elif len(plot_data) == 1:
5322 x_axis_text = str(list(plot_data.keys())[0])
5323 if len(observed_clusters) == 1:
5324 x_axis_text = str(list(observed_clusters)[0])
5325 else:
5326 x_axis_text = "Synteny Block"
5328 x_axis_label = text_to_path(text=x_axis_text, size=font_size, family=font_family)
5329 x_axis_label_w, x_axis_label_h = x_axis_label["w"], x_axis_label["h"]
5330 x_axis_label_y = height - margin
5332 x_axis_x1 = y_axis_x
5333 x_axis_x2 = width - margin
5334 x_axis_w = x_axis_x2 - x_axis_x1
5335 x_axis_step = x_axis_w / len(plot_data)
5337 x_axis_label_x = x_axis_x1 + (x_axis_w / 2) - (x_axis_label_w / 2)
5339 x_tick_hmax = 0
5340 x_tick_wmax = 0
5342 start_coord, end_coord = 0, total_length
5345 if len(plot_data) > max_blocks or len(plot_data) == 1:
5347 if len(plot_data) > max_blocks:
5348 logging.info(f"Plot data exceeds max_blocks of {max_blocks}, enabling pangenome coordinates.")
5349 start_coord, end_coord = 0, pangenome_length
5350 elif len(plot_data) == 1:
5351 only_synteny = list(plot_data.keys())[0]
5352 logging.info(f"Single synteny blocked detected, enabling synteny coordinates.")
5353 # For one cluster, we might have variable start positions
5354 if len(observed_clusters) == 1:
5355 only_cluster = list(observed_clusters)[0]
5356 c_data = plot_data[only_synteny]["clusters"][only_cluster]
5357 start_coord, end_coord = c_data["cluster_pos"]
5358 # Round down/up to nearest 10?
5359 start_coord = start_coord - (start_coord % 10)
5360 end_coord = end_coord + (end_coord % 10)
5362 label = text_to_path(text="0123456789", size=font_size * 0.75, family=font_family)
5363 h = label["h"]
5364 # Put 1/2 h spacing in between
5365 max_labels = x_axis_w / (2 * h)
5366 # Prefer either 10 or 4 ticks
5367 if max_labels >= 10:
5368 max_labels = 10
5369 elif max_labels >= 4:
5370 max_labels = 4
5372 step_size = 1
5373 num_labels = max_labels + 1
5374 # 10, 50, 100, 500, 1000...
5375 while num_labels > max_labels:
5376 if str(step_size).startswith("5"):
5377 step_size = step_size * 2
5378 else:
5379 step_size = step_size * 5
5380 num_labels = total_length / step_size
5382 x_tick_vals = list(range(start_coord, end_coord + 1, step_size))
5384 x_axis_step = x_axis_w / num_labels
5385 for val in x_tick_vals:
5386 text = str(val)
5387 label = text_to_path(text=text, size=font_size * 0.75, family=font_family)
5388 w, h = label["w"], label["h"]
5389 x_tick_hmax, x_tick_wmax = max(x_tick_hmax, h), max(x_tick_wmax, w)
5390 plot_data[synteny]["ticks"][text] = {"label": label}
5391 else:
5392 for synteny in plot_data:
5393 label = text_to_path(text=str(synteny), size=font_size * 0.75, family=font_family)
5394 w, h = label["w"], label["h"]
5395 x_tick_hmax, x_tick_wmax = max(x_tick_hmax, h), max(x_tick_wmax, w)
5396 plot_data[synteny]["ticks"][synteny] = {"label": label}
5399 # We now have enough information to finalize the axes coordinates
5400 x_axis_y = x_axis_label_y - (x_axis_label_h * 2) - x_tick_wmax - tick_pad - tick_len
5401 y_axis_y2 = x_axis_y
5402 y_axis_h = y_axis_y2 - y_axis_y1
5403 y_axis_label_y = y_axis_y1 + (y_axis_h / 2)
5404 y_axis_label["ry"] = y_axis_label_y
5406 # -------------------------------------------------------------------------
5407 logging.info(f"Positioning labels.")
5409 for i_g,g in enumerate(x_axis_label["glyphs"]):
5410 x_axis_label["glyphs"][i_g]["x"] = x_axis_label_x + g["x"]
5411 x_axis_label["glyphs"][i_g]["y"] = x_axis_label_y
5413 for i_g,g in enumerate(y_axis_label["glyphs"]):
5414 y_axis_label["glyphs"][i_g]["x"] = y_axis_label_x + g["x"]
5415 y_axis_label["glyphs"][i_g]["y"] = y_axis_label_y
5417 # -------------------------------------------------------------------------
5418 # X-Axis
5420 logging.info("Creating x-axis.")
5422 prev_x = x_axis_x1
5423 for synteny in plot_data:
5424 s_data = plot_data[synteny]
5426 if len(plot_data) == 1:
5427 plot_data[synteny]["x1"] = x_axis_x1
5428 plot_data[synteny]["x2"] = x_axis_x2
5430 for tick,t_data in plot_data[synteny]["ticks"].items():
5432 if prop_x_axis == True:
5433 xw = x_axis_w * s_data["prop"]
5434 else:
5435 xw = x_axis_step
5437 if len(plot_data) > 1 and len(plot_data) <= max_blocks:
5438 plot_data[synteny]["x1"] = prev_x
5439 plot_data[synteny]["x2"] = prev_x + xw
5441 # Pixel coordinates of this synteny block
5442 if len(plot_data) == 1 or len(plot_data) > max_blocks:
5443 x = linear_scale(int(tick), start_coord, end_coord, x_axis_x1, x_axis_x2)
5444 else:
5445 x = prev_x + (xw / 2)
5447 # Tick Label
5448 label = t_data["label"]
5449 y = x_axis_y + tick_len + label["w"] + (tick_pad)
5450 t_data["label"]["rx"] = x
5451 t_data["label"]["ry"] = y
5453 # Tick Line
5454 t_data["line"] = {"x": x, "y1": x_axis_y, "y2": x_axis_y + tick_len}
5455 # Center label on the very last glyph
5456 center_glyph = label["glyphs"][-1]
5457 label_y = y + (center_glyph["h"] / 2)
5458 for i_g,g in enumerate(label["glyphs"]):
5459 t_data["label"]["glyphs"][i_g]["x"] = x + g["x"]
5460 t_data["label"]["glyphs"][i_g]["y"] = label_y
5462 plot_data[synteny]["ticks"][tick] = t_data
5464 prev_x += xw
5466 # -------------------------------------------------------------------------
5467 # Y-Axis
5469 logging.info("Creating y-axis.")
5471 y_axis = OrderedDict({k:{"label": v, "line": None} for k,v in y_tick_labels.items()})
5473 for val,label in y_tick_labels.items():
5474 y = linear_scale(val, max_log10, min_log10, y_axis_y1, y_axis_y2)
5475 x1, x2 = y_axis_x - tick_len, y_axis_x
5477 # Tick Line
5478 y_axis[val]["line"] = { "x1": x1, "x2": x2, "y": y }
5480 # Tick Label
5481 y_axis[val]["label"] = label
5482 center_glyph = label["glyphs"][-1]
5483 label_y = y + (center_glyph["h"] / 2)
5484 label_x = x1 - tick_pad - label["w"]
5486 for i_g,g in enumerate(label["glyphs"]):
5487 label["glyphs"][i_g]["x"] = label_x + g["x"]
5488 label["glyphs"][i_g]["y"] = label_y
5490 y_axis[val]["label"] = label
5492 # -------------------------------------------------------------------------
5493 # Render
5495 radius = 2
5497 svg_path = os.path.join(output_dir, f"{prefix}plot.svg")
5498 logging.info(f"Rendering output svg ({width}x{height}): {svg_path}")
5499 with open(svg_path, 'w') as outfile:
5500 header = textwrap.dedent(
5501 f"""\
5502 <svg
5503 version="1.1"
5504 xmlns="http://www.w3.org/2000/svg"
5505 xmlns:xlink="http://www.w3.org/1999/xlink"
5506 preserveAspectRatio="xMidYMid meet"
5507 width="{width}"
5508 height="{height}"
5509 viewbox="0 0 {width} {height}">
5510 """)
5511 outfile.write(header)
5513 indent = " "
5515 # White canvas background
5516 outfile.write(f"{indent * 1}<g id='Background'>\n")
5517 background = f"{indent * 2}<rect width='{width}' height='{height}' x='0' y='0' style='fill:white;stroke-width:1;stroke:white'/>"
5518 outfile.write(background + "\n")
5519 outfile.write(f"{indent * 1}</g>\n")
5521 # Axes
5522 outfile.write(f"{indent * 1}<g id='axis'>\n")
5524 # ---------------------------------------------------------------------
5525 # X Axis
5527 outfile.write(f"{indent * 2}<g id='x'>\n")
5529 # Background panels
5530 if len(plot_data) <= max_blocks:
5531 outfile.write(f"{indent * 3}<g id='Background Panels'>\n")
5532 for i_s,synteny in enumerate(plot_data):
5533 if i_s % 2 == 0: continue
5534 x1, x2 = plot_data[synteny]["x1"], plot_data[synteny]["x2"]
5535 y = margin
5536 w = x2 - x1
5537 h = y_axis_h
5538 rect = f"{indent * 2}<rect width='{w}' height='{h}' x='{x1}' y='{y}' style='fill:grey;fill-opacity:0.10'/>"
5539 outfile.write(rect + "\n")
5540 outfile.write(f"{indent * 3}</g>\n")
5542 # X axis line
5543 outfile.write(f"{indent * 3}<g id='line'>\n")
5544 line = f"{indent * 4}<path d='M {x_axis_x1} {x_axis_y} H {x_axis_x2}' style='stroke:black;stroke-width:2;fill:none'/>"
5545 outfile.write(line + "\n")
5546 outfile.write(f"{indent * 3}</g>\n")
5548 # X axis label
5549 outfile.write(f"{indent * 3}<g id='label'>\n")
5550 for g in x_axis_label["glyphs"]:
5551 line = f"{indent * 4}<path transform='translate({g['x']},{g['y']})' d='{g['d']}'/>"
5552 outfile.write(line + "\n")
5553 outfile.write(f"{indent * 3}</g>\n")
5555 # X axis ticks
5556 outfile.write(f"{indent * 3}<g id='ticks'>\n")
5557 for synteny in plot_data:
5558 for tick,t_data in plot_data[synteny]["ticks"].items():
5559 # Tick line
5560 t = t_data["line"]
5561 line = f"{indent * 4}<path d='M {t['x']} {t['y1']} V {t['y2']}' style='stroke:black;stroke-width:1;fill:none'/>"
5562 outfile.write(line + "\n")
5563 # Tick label
5564 label = t_data["label"]
5565 rx, ry = label["rx"], label["ry"]
5566 for g in label["glyphs"]:
5567 line = f"{indent * 4}<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}'/>"
5568 outfile.write(line + "\n")
5570 outfile.write(f"{indent * 3}</g>\n")
5572 # Close x-axis
5573 outfile.write(f"{indent * 2}</g>\n")
5575 # ---------------------------------------------------------------------
5576 # Y Axis
5578 outfile.write(f"{indent * 2}<g id='y-axis'>\n")
5580 # Y-axis line
5581 outfile.write(f"{indent * 3}<g id='line'>\n")
5582 line = f"{indent * 4}<path d='M {y_axis_x} {y_axis_y1} V {y_axis_y2}' style='stroke:black;stroke-width:2;fill:none'/>"
5583 outfile.write(line + "\n")
5584 outfile.write(f"{indent * 3}</g>\n")
5586 # Y axis label
5587 outfile.write(f"{indent * 3}<g id='label'>\n")
5588 for g in y_axis_label["glyphs"]:
5589 rx, ry = y_axis_label["rx"], y_axis_label["ry"]
5590 line = f"{indent * 4}<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}'/>"
5591 outfile.write(line + "\n")
5592 outfile.write(f"{indent * 3}</g>\n")
5594 outfile.write(f"{indent * 3}<g id='ticks'>\n")
5595 # Ticks
5596 for t,t_data in y_axis.items():
5597 # Tick line
5598 t = t_data["line"]
5599 line = f"{indent * 3}<path d='M {t['x1']} {t['y']} H {t['x2']}' style='stroke:black;stroke-width:1;fill:none'/>"
5600 outfile.write(line + "\n")
5601 # Tick Label
5602 label = t_data["label"]
5603 for g in label["glyphs"]:
5604 line = f"{indent * 4}<path transform='translate({g['x']},{g['y']})' d='{g['d']}'/>"
5605 outfile.write(line + "\n")
5606 outfile.write(f"{indent * 3}</g>\n")
5608 # Close y-axis
5609 outfile.write(f"{indent * 2}</g>\n")
5611 # Close all Axes
5612 outfile.write(f"{indent * 1}</g>\n")
5615 # -------------------------------------------------------------
5616 # Data
5617 outfile.write(f"{indent * 1}<g id='variants'>\n")
5619 for i_s, synteny in enumerate(plot_data):
5620 s_data = plot_data[synteny]
5621 if len(plot_data) <= max_blocks:
5622 s_x1, s_x2 = s_data["x1"], s_data["x2"]
5623 else:
5624 s_x1, s_x2 = x_axis_x1, x_axis_x2
5625 s_len = s_data["length"]
5627 variants = plot_data[synteny]["variants"].values()
5628 variants_order = [v for v in variants if v["type"] == "presence_absence"]
5629 variants_order += [v for v in variants if v["type"] == "structural"]
5630 variants_order += [v for v in variants if v["type"] == "snp"]
5631 opacity = 0.50
5633 for v_data in variants_order:
5634 # If we're plotting according to pangenome coordinates
5635 # we use the panGWAS green color
5636 if len(plot_data) >= max_blocks:
5637 f = "#356920"
5638 # Otherwise we alternate between the classic blue and orange
5639 else:
5640 f = "#242bbd" if i_s % 2 == 0 else "#e37610"
5641 vy = linear_scale(v_data["-log10(p)"], min_log10, max_log10, y_axis_y2, y_axis_y1)
5642 variant = v_data["variant"].replace("'", "")
5644 # In many browsers, the title will become hover text!
5645 hover_text = []
5646 for k,v in v_data.items():
5647 if "coords" not in k:
5648 text = f"{k}: {v}"
5649 else:
5650 coords = ", ".join(["-".join([str(xs) for xs in x]) for x in v])
5651 text = f"{k}: {coords}"
5652 hover_text.append(text)
5653 hover_text = "\n".join(hover_text)
5654 title = f"{indent * 4}<title>{hover_text}</title>"
5656 outfile.write(f"{indent * 2}<g id='{variant}'>\n")
5658 if len(observed_clusters) == 1:
5659 s_start = start_coord
5660 s_end = end_coord
5661 else:
5662 s_start = 1
5663 s_end = s_len
5665 # Decide on the coordinate system
5666 if len(plot_data) > max_blocks:
5667 coordinate_system = "pangenome"
5668 s_start, s_end = 0, pangenome_length
5669 else:
5670 coordinate_system = "synteny"
5672 pos = v_data[f"{coordinate_system}_coords"][0][0]
5673 vx = linear_scale(pos, s_start, s_end, s_x1, s_x2)
5675 # Circle
5676 if v_data["type"] == "snp":
5677 circle = f"{indent * 3}<circle cx='{vx}' cy='{vy}' r='{radius}' style='fill:{f};fill-opacity:{opacity}'>"
5678 outfile.write(circle + "\n")
5679 outfile.write(title + "\n")
5680 outfile.write(f"{indent * 3}</circle>" + "\n")
5682 # Square
5683 elif v_data["type"] == "presence_absence":
5684 pos = v_data[f"{coordinate_system}_coords"][0][0]
5685 vx = linear_scale(pos, s_start, s_end, s_x1, s_x2)
5686 rect = f"{indent * 3}<rect width='{radius*2}' height='{radius*2}' x='{vx}' y='{vy - radius}' style='fill:{f};fill-opacity:{opacity}'>"
5687 outfile.write(rect + "\n")
5688 outfile.write(title + "\n")
5689 outfile.write(f"{indent * 3}</rect>" + "\n")
5691 # Diamond
5692 elif v_data["type"] == "structural":
5693 rx, ry = vx + radius, vy + radius
5694 rect = f"{indent * 3}<rect transform='rotate(-45, {rx}, {ry})' width='{radius*2}' height='{radius*2}' x='{vx}' y='{vy - radius}' style='fill:{f};fill-opacity:{opacity}'>"
5695 outfile.write(rect + "\n")
5696 outfile.write(title + "\n")
5697 outfile.write(f"{indent * 3}</rect>" + "\n")
5699 outfile.write(f"{indent * 2}</g>\n")
5700 outfile.write(f"{indent * 1}</g>\n")
5702 # p-value threshold
5703 outfile.write(f"{indent * 1}<g id='alpha'>\n")
5704 x1, x2 = x_axis_x1, x_axis_x2
5705 y = linear_scale(alpha_log10, min_log10, max_log10, y_axis_y2, y_axis_y1)
5706 line = f"{indent * 2}<path d='M {x1} {y} H {x2}' style='stroke:grey;stroke-width:1;fill:none' stroke-dasharray='4 4'/>"
5707 outfile.write(line + "\n")
5708 outfile.write(f"{indent * 1}</g>\n")
5710 outfile.write("</svg>" + "\n")
5712 png_path = os.path.join(output_dir, f"{prefix}plot.png")
5713 logging.info(f"Rendering output png ({width}x{height}): {png_path}")
5714 cairosvg.svg2png(url=svg_path, write_to=png_path, output_width=width, output_height=height, scale=png_scale)
5716 return svg_path
5719def heatmap(
5720 tree: str=None,
5721 tree_format: str="newick",
5722 gwas: str=None,
5723 rtab: str=None,
5724 outdir: str = ".",
5725 prefix: str=None,
5726 focal: str=None,
5727 min_score: float = None,
5728 tree_width=100,
5729 margin=20,
5730 root_branch=10,
5731 tip_pad=10,
5732 font_size=16,
5733 font_family="Roboto",
5734 png_scale=2.0,
5735 heatmap_scale=1.5,
5736 palette={"presence_absence": "#140d91", "structural": "#7a0505", "snp": "#0e6b07", "unknown": "#636362"},
5737 args: str = None,
5738 ):
5739 """
5740 Plot a tree and/or a heatmap of variants.
5742 Takes as input a newick tree and/or a table of variants. The table can be either
5743 an Rtab file, or the locus effects TSV output from the gwas subcommand.
5744 If both a tree and a table are provided, the tree will determine the sample order
5745 and arrangement. If just a table is provided, sample order will follow the
5746 order of the sample columns. A TXT of focal sample IDs can also be supplied
5747 with one sample ID per line. Outputs a plot in SVG and PNG format.
5749 >>> plot(tree="tree.rooted.nwk")
5750 >>> plot(rtab="combine.Rtab")
5751 >>> plot(gwas="resistant.locus_effects.significant.tsv")
5752 >>> plot(tree="tree.rooted.nwk", rtab="combine.Rtab", focal="focal.txt")
5753 >>> plot(tree="tree.rooted.nwk", gwas="resistant.locus_effects.significant.tsv")
5754 >>> plot(tree="tree.rooted.nwk, tree_width=500)
5756 :param tree: Path to newick tree.
5757 :param gwas: Path to tsv table of locus effects from gwas subcommand.
5758 :param rtab: Path to Rtab file of variants.
5759 :param prefix: Output prefix.
5760 :param focal: Path to text file of focal sample IDs to highlight.
5761 :param min_score: Filter GWAS variants for a minimum score.
5762 :param tree_width: Width in pixels of tree.
5763 :param margin: Size in pixels of plot margins.
5764 :param root_branch: Width in pixels of root branch.
5765 :param tip_pad: Pad in pixels between tip labels and branches.
5766 :param font_size: Font size.
5767 :param font_family: Font family.
5768 :param png_scale: Float that adjusts png scale relative to svg.
5769 :param heatmap_scale: Float that adjusts heatmap box scale relative to text.
5770 :param palette: Dict of variant types to colors.
5771 """
5773 tree_path, gwas_path, focal_path, rtab_path, output_dir = tree, gwas, focal, rtab, outdir
5774 prefix = f"{prefix}." if prefix != None else ""
5776 if not tree_path and not gwas_path and not rtab_path:
5777 msg = "Either a tree (--tree), a GWAS table (--gwas) or an Rtab --rtab) must be supplied."
5778 logging.error(msg)
5779 raise Exception(msg)
5781 elif gwas_path and rtab_path:
5782 msg = "A GWAS table (--gwas) is mutually exclusive with an Rtab (--rtab) file."
5783 logging.error(msg)
5784 raise Exception(msg)
5786 # Check output directory
5787 check_output_dir(output_dir)
5789 logging.info(f"Importing cairo.")
5790 import cairosvg
5792 # Todo: If we want to color by beta/p-value
5793 # logging.info(f"Importing matplotlib color palettes.")
5794 # import matplotlib as mpl
5795 # import matplotlib.pyplot as plt
5796 # from matplotlib.colors import rgb2hex
5798 # -------------------------------------------------------------------------
5799 # Parse Focal Samples
5801 focal = []
5802 if focal_path:
5803 logging.info(f"Parsing focal samples: {focal_path}")
5804 with open(focal_path) as infile:
5805 for line in infile:
5806 sample = line.strip()
5807 if sample != "" and sample not in focal:
5808 focal.append(sample)
5810 # -------------------------------------------------------------------------
5811 # Parse Input Tree
5813 if tree_path:
5814 logging.info(f"Parsing {tree_format} tree: {tree_path}")
5815 tree = read_tree(tree=tree_path, tree_format=tree_format)
5816 tree.is_rooted = True
5817 tree.ladderize()
5818 else:
5819 root_branch = 0
5820 tree_width = 0
5822 # -------------------------------------------------------------------------
5823 # Node labels
5825 node_labels = {}
5826 node_labels_wmax = 0
5827 node_labels_wmax_text = None
5828 node_labels_hmax = 0
5829 node_labels_hmax_text = None
5830 all_samples = []
5831 tree_tips = []
5832 tree_nodes = []
5834 # Option 1: Tips from the tree
5835 if tree_path:
5836 # Cleanup up labels, dendropy sometimes inserts quotations
5837 logging.info(f"Parsing tree labels.")
5838 node_i = 0
5839 for node in list(tree.preorder_node_iter()):
5840 # Tip labels
5841 if node.taxon:
5842 node.label = str(node.taxon)
5843 logging.debug(f"Tree Tip: {node.label}")
5844 # Internal node labels
5845 else:
5846 if not node.label or "NODE_" not in node.label:
5847 if node_i == 0:
5848 logging.warning(f"Using 'NODE_<i> nomenclature for internal node labels.")
5849 node.label = f"NODE_{node_i}"
5850 node_i += 1
5851 node.label = node.label.replace("\"", "").replace("'", "")
5853 # Keep a list of tip labels
5854 tree_tips = [n.label for n in tree.leaf_node_iter()]
5855 tree_nodes = [n.label for n in tree.preorder_node_iter()]
5856 all_samples = copy.deepcopy(tree_tips)
5858 # Option 2: Tips from GWAS columns
5859 if gwas_path:
5860 logging.info(f"Parsing gwas labels: {gwas_path}")
5861 with open(gwas_path) as infile:
5862 header = infile.readline().strip().split("\t")
5863 # Sample IDs come after the dbxref_alt column (if we used summarized clusters)
5864 # otherwise they come after the "score" column
5865 if "dbxref_alt" in header:
5866 samples = header[header.index("dbxref_alt")+1:]
5867 else:
5868 samples = header[header.index("score")+1:]
5869 for sample in samples:
5870 if sample not in all_samples:
5871 all_samples.append(sample)
5873 # Option 3: Tips from Rtab columns
5874 if rtab_path:
5875 logging.info(f"Parsing rtab labels: {rtab_path}")
5876 with open(rtab_path) as infile:
5877 header = infile.readline().strip().split("\t")
5878 # Sample IDs are all columns after the first
5879 samples = header[1:]
5880 for sample in samples:
5881 if sample not in all_samples:
5882 all_samples.append(sample)
5884 # Figure out the largest sample label, we'll need to know
5885 # this to position elements around it.
5886 logging.info(f"Calculating sample label dimensions.")
5888 for node in tqdm(all_samples):
5889 label = text_to_path(text=node, size=font_size, family=font_family)
5890 node_labels[node] = label
5891 w, h = label["w"], label["h"]
5892 if w > node_labels_wmax:
5893 node_labels_wmax = w
5894 node_labels_wmax_text = node
5895 if h > node_labels_hmax:
5896 node_labels_hmax = h
5897 node_labels_hmax_text = node
5899 node_labels_wmax = math.ceil(node_labels_wmax)
5900 node_labels_hmax = math.ceil(node_labels_hmax)
5902 logging.info(f"The widest sample label is {math.ceil(node_labels_wmax)}px: {node_labels_wmax_text}")
5903 logging.info(f"The tallest sample label is {math.ceil(node_labels_hmax)}px: {node_labels_hmax_text}")
5905 # -------------------------------------------------------------------------
5906 # Variant Labels
5908 # Figure out the largest variant label for the heatmap
5910 variant_labels_wmax = 0
5911 variant_labels_wmax_text = ""
5912 variants = OrderedDict()
5914 # Option 1. Variants from GWAS rows
5916 # The score can range from -1 to +1
5917 # We will make a gradient palette based on the score for the fill
5918 score_min = None
5919 score_max = None
5920 if gwas_path:
5921 logging.info(f"Parsing GWAS variants: {gwas_path}")
5922 with open(gwas_path) as infile:
5923 header = [line.strip() for line in infile.readline().split("\t")]
5924 for line in infile:
5925 row = [l.strip() for l in line.split("\t")]
5926 data = {k:v for k,v in zip(header, row)}
5927 variant = f"{data['variant']}"
5928 if "score" in data:
5929 score = float(data["score"])
5930 if min_score != None and score < min_score:
5931 logging.info(f"Excluding {variant} due to score {score} < {min_score}")
5932 continue
5933 if score < 0:
5934 logging.warning(f"Converting {variant} score of {score} to 0.00")
5935 score = 0.0
5936 data["score"] = score
5937 score_min = score if score_min == None else min(score_min, score)
5938 score_max = score if score_max == None else max(score_max, score)
5939 variants[variant] = {"data": data }
5941 # Option 2. Variants from Rtab rows
5942 elif rtab_path:
5943 logging.info(f"Parsing Rtab variants: {rtab_path}")
5944 with open(rtab_path) as infile:
5945 header = infile.readline().strip().split("\t")
5946 lines = infile.readlines()
5947 for line in tqdm(lines):
5948 row = [l.strip() for l in line.split("\t")]
5949 variant = f"{row[0]}"
5950 data = {col:val for col,val in zip(header,row)}
5951 variants[variant] = {"data": data }
5953 # Render variant label text to svg paths
5954 variant_types = set()
5955 if len(variants) > 0:
5956 logging.info(f"Calculating variant label dimensions.")
5957 for variant in tqdm(variants):
5958 try:
5959 variant_type = variant.split("|")[1].split(":")[0]
5960 except IndexError:
5961 variant_type = "unknown"
5962 # Space out content around the pipe delim
5963 text = variant.replace("|", " | ").replace("presence_absence", "present")
5964 label = text_to_path(text, size=font_size, family=font_family)
5965 w = label["w"]
5966 if w > variant_labels_wmax:
5967 variant_labels_wmax = w
5968 variant_labels_wmax_text = text
5969 variants[variant]["text"] = text
5970 variants[variant]["type"] = variant_type
5971 variants[variant]["label"] = label
5972 variant_types.add(variant_type)
5973 variants[variant]["fill"] = palette[variant_type]
5975 logging.info(f"Grouping variants by type.")
5976 variants_order = OrderedDict()
5977 for variant_type in palette:
5978 for variant,variant_data in variants.items():
5979 if variant_type != variant_data["type"]: continue
5980 variants_order[variant] = variant_data
5982 variants = variants_order
5984 if variant_labels_wmax_text:
5985 logging.info(f"The widest variant label is {math.ceil(variant_labels_wmax)}px: {variant_labels_wmax_text}")
5986 # Order variant types by palette
5987 variant_types = [v for v in palette if v in variant_types]
5989 # -------------------------------------------------------------------------
5990 # Initialize plot data
5992 plot_data = OrderedDict()
5994 # If tree provided, orient plot around it
5995 if tree_path:
5996 for node in tree.preorder_node_iter():
5997 plot_data[node.label] = {
5998 "x": None,
5999 "y": None,
6000 "parent": node.parent_node.label if node.parent_node else None,
6001 "children": [c.label for c in node.child_nodes()],
6002 "label": node_labels[node.label] if node.label in node_labels else OrderedDict(),
6003 "heatmap": OrderedDict()
6004 }
6006 # Check if any variant samples are missing from the tree tips
6007 for node in all_samples:
6008 if node not in plot_data:
6009 plot_data[node] = {
6010 "x": None,
6011 "y": None,
6012 "parent": None,
6013 "children": [],
6014 "label": node_labels[node] if node in node_labels else OrderedDict(),
6015 "heatmap": OrderedDict()
6016 }
6018 # -------------------------------------------------------------------------
6019 # Node coordinates
6021 logging.info(f"Calculating node coordinates.")
6023 # Identify the branch that sticks out the most, this will
6024 # determine tip label placement
6025 node_xmax = 0
6026 node_i = 0
6027 node_xmax_text = node_labels_wmax_text
6029 # X Coordinates: Distance to the root
6030 if tree_path:
6031 for node,dist in zip(
6032 tree.preorder_node_iter(),
6033 tree.calc_node_root_distances(return_leaf_distances_only=False)
6034 ):
6035 plot_data[node.label]["x"] = dist
6036 if dist > node_xmax:
6037 node_xmax = dist
6038 node_xmax_text = node.label
6040 logging.info(f"The most distant tree node is: {node_xmax_text}")
6042 # Y Coordinates: Start with tips
6043 for node in list(tree.leaf_nodes()):
6044 plot_data[node.label]["y"] = node_i
6045 node_i += 1
6047 # Y Coordinates: Place internal nodes at the midpoint of (immediate?) children
6048 for node in tree.postorder_internal_node_iter():
6049 children = [c.label for c in node.child_nodes()]
6050 children_y = [plot_data[c]["y"] for c in children]
6051 y = sum(children_y) / len(children)
6052 plot_data[node.label]["y"] = y
6054 # Set coords for non-tree samples
6055 for node in plot_data:
6056 if plot_data[node]["x"] == None:
6057 plot_data[node]["x"] = 0
6058 plot_data[node]["y"] = node_i
6059 node_i += 1
6061 # -------------------------------------------------------------------------
6062 # Coordinate scaling
6064 # Rescale coordinates to pixels
6065 # X Scaling: Based on the user's requested tree width
6066 # Y Scaling: Based on the number of nodes in the tree
6068 logging.info(f"Rescaling branch lengths to pixels.")
6069 node_xmax = max([c["x"] for c in plot_data.values()])
6070 node_ymax = max([c["y"] for c in plot_data.values()])
6072 # Heatmap dimensions for scaling
6073 heatmap_s = (node_labels_hmax * heatmap_scale)
6074 heatmap_pad = heatmap_s / 4
6075 heatmap_r = 2
6076 heatmap_h = (heatmap_s * len(all_samples)) + (heatmap_pad * (len(all_samples) - 1))
6077 logging.info(f"Heatmap boxes will be {heatmap_s}px wide with {heatmap_pad}px of padding.")
6079 # Rescale the x coordinates based on pre-defined maximum width
6080 if node_xmax > 0:
6081 x_scale = math.floor(tree_width / node_xmax)
6082 else:
6083 x_scale = 1
6084 # Rescale y coordinates based on the heatmap and text label dimensions
6085 if len(all_samples) > 1:
6086 tree_h = math.ceil((heatmap_s * len(all_samples)) + (heatmap_pad * (len(all_samples) - 1)) - heatmap_s)
6087 y_scale = math.floor(tree_h / node_ymax)
6088 else:
6089 tree_h = math.ceil((heatmap_s * len(all_samples)))
6090 y_scale = math.floor(tree_h)
6092 # Rescale the node coordinates, and position them
6093 logging.info(f"Positioning nodes.")
6094 for node, data in plot_data.items():
6095 # Locate the node's x coordinate, based on the distance to root
6096 plot_data[node]["x"] = math.floor(margin + root_branch + (data["x"] * x_scale))
6097 # Locate the node's y coordinate, if heatmap is available
6098 if variant_labels_wmax > 0:
6099 y = math.floor(margin + variant_labels_wmax + heatmap_s + (data["y"] * y_scale))
6100 else:
6101 y = math.floor(margin + data["y"] * y_scale)
6102 plot_data[node]["y"] = y
6104 # Get the new X coord where tips start
6105 node_xmax = plot_data[node_xmax_text]["x"]
6106 node_labels_x = node_xmax + tip_pad
6108 # Rescale the node label coordinates
6109 logging.info(f"Positioning tip labels.")
6110 for node in all_samples:
6111 data = plot_data[node]
6112 # Adjust the label y position to center on the first glyph
6113 first_glyph = data["label"]["glyphs"][0]
6114 first_glyph_h = first_glyph["h"]
6115 label_y = data["y"] + (first_glyph_h / 2)
6116 plot_data[node]["label"]["y"] = label_y
6117 label_w = plot_data[node]["label"]["w"]
6118 for i,glyph in enumerate(data["label"]["glyphs"]):
6119 # Left alignment
6120 #plot_data[node]["label"]["glyphs"][i]["x"] = node_labels_x + glyph["x"]
6121 # Right alignment
6122 plot_data[node]["label"]["glyphs"][i]["x"] = node_labels_x + (node_labels_wmax - label_w) + glyph["x"]
6123 plot_data[node]["label"]["glyphs"][i]["y"] = label_y
6125 # Heatmap labels
6126 logging.info(f"Positioning variant labels.")
6127 heatmap_x = node_labels_x + node_labels_wmax + tip_pad
6128 heatmap_w = 0
6130 # Add spaces between the different variant types
6131 prev_var_type = None
6132 prev_x = None
6134 for variant in variants:
6135 # Adjust the position to center on the first glyph
6136 data = variants[variant]["data"]
6137 variant_type = variants[variant]["type"]
6138 glyphs = variants[variant]["label"]["glyphs"]
6139 first_glyph = glyphs[0]
6140 first_glyph_h = first_glyph["h"]
6141 first_glyph_offset = first_glyph_h + ((heatmap_s - first_glyph_h)/ 2)
6143 # Set the absolute position of the label
6144 if prev_x == None:
6145 x = heatmap_x + first_glyph_offset
6146 elif prev_var_type != None and variant_type != prev_var_type:
6147 x += heatmap_s + (heatmap_pad * 3)
6148 else:
6149 x = prev_x + (heatmap_s + heatmap_pad)
6151 # Set the absolute position of the box
6152 variants[variant]["box"] = {"x": x - first_glyph_offset}
6153 prev_x = x
6154 prev_var_type = variant_type
6155 # Add an extra 2 pixels to make it reach just beyond the final box
6156 heatmap_w = (2 + heatmap_s + x - first_glyph_offset) - heatmap_x
6158 y = margin + variant_labels_wmax
6159 variants[variant]["label"]["x"] = x
6160 variants[variant]["label"]["y"] = y
6161 for i_g,g in enumerate(glyphs):
6162 variants[variant]["label"]["glyphs"][i_g]["x"] = x + g["x"]
6163 variants[variant]["label"]["glyphs"][i_g]["y"] = y
6165 # -------------------------------------------------------------------------
6166 # Heatmap Table Data
6168 score_min_alpha, score_max_alpha = 0.10, 1.0
6170 logging.info(f"Positioning heatmap boxes.")
6171 for i,variant in enumerate(variants):
6172 variant_type = variants[variant]["type"]
6173 data = variants[variant]["data"]
6174 label = variants[variant]["label"]
6175 x = variants[variant]["box"]["x"]
6176 for node in all_samples:
6177 # Ex. a sample in tree but not heatmap
6178 if node not in data: continue
6179 v = data[node]
6180 v = int(v) if v.isdigit() else str(v)
6181 fill_opacity = 1.0
6182 # Light grey for missing values (".")
6183 score = data["score"] if "score" in data else None
6184 if v == 1:
6185 box_fill = variants[variant]["fill"]
6186 # Use score range if possible
6187 if score_min != None and "score" in data:
6188 score = float(data["score"])
6189 # Option 1: Fixed range from 0 to 1
6190 fill_opacity = linear_scale(score, 0, 1.0, score_min_alpha, score_max_alpha)
6191 # Option 2. Relative range from score_min to score_max
6192 # fill_opacity = linear_scale(score, score_min, score_max, score_min_alpha, score_max_alpha)
6193 box_stroke = "black"
6194 elif v == 0:
6195 box_fill = "white"
6196 box_stroke = "black"
6197 else:
6198 box_fill = "none"
6199 box_stroke = "none"
6200 y = plot_data[node]["y"] - (heatmap_s / 2)
6201 d = {
6202 "v": v, "x": x, "y": y, "w": heatmap_s, "h": heatmap_s,
6203 "r": heatmap_r, "fill": box_fill, "fill_opacity": fill_opacity, "stroke": box_stroke,
6204 "hovertext": data,
6205 }
6206 plot_data[node]["heatmap"][variant] = d
6208 # -----------------------------------------------------------------------------
6209 # Legend
6211 legend_x = (heatmap_x + heatmap_w + heatmap_s)
6212 legend_y = margin + variant_labels_wmax + tip_pad
6213 legend_w = 0
6214 legend_h = 0
6215 legend = OrderedDict()
6216 legend_labels_hmax = 0
6217 legend_labels_wmax = 0
6218 legend_title_wmax = 0
6219 legend_title_hmax = 0
6221 if score_min != None and score_max != None:
6222 logging.info(f"Drawing legend at: {legend_x}, {legend_y}")
6223 legend_h = (heatmap_s * 3) + (heatmap_pad * 2)
6224 # TBD: Think about negative score handling?
6225 # Option 1. Fixed Values
6226 legend_values = ["0.00", "0.25", "0.50", "0.75", "1.00"]
6227 # # Option 2. Relative to score max
6228 # interval = score_max / 4
6229 # legend_values = [round(i * interval,2) for i in range(0,4)] + [round(score_max, 2)]
6230 # legend_values = ["{:.2f}".format(v) for v in legend_values]
6232 # Tick labels: convert to path
6233 tick_labels = [text_to_path(text, size=font_size * 0.50, family=font_family) for text in legend_values]
6234 for label in tick_labels:
6235 legend_labels_hmax = max(legend_labels_hmax, label["h"])
6236 legend_labels_wmax = max(legend_labels_wmax, label["w"])
6237 title_label = text_to_path("Score", size=font_size * 0.75, family=font_family)
6238 legend_title_hmax = max(legend_title_hmax, label["h"])
6239 legend_title_wmax = max(legend_title_wmax, label["w"])
6242 # Width of legend + tick len + pad + label wmax + pad
6243 entry_w = heatmap_s + heatmap_pad + (heatmap_pad / 2) + legend_labels_wmax + heatmap_s
6244 if (legend_title_wmax + heatmap_s) > entry_w:
6245 entry_w = legend_title_wmax + heatmap_s
6247 for i_v, variant_type in enumerate(variant_types):
6248 box = {
6249 "x": legend_x + (i_v * entry_w),
6250 "y": legend_y,
6251 "w": heatmap_s,
6252 "h": (heatmap_s * 3) + (heatmap_pad * 2),
6253 "fill": f"url(#{variant_type}_gradient)",
6254 }
6255 title = copy.deepcopy(title_label)
6256 title["x"] = box["x"]
6257 title["y"] = box["y"] - tip_pad
6258 for i_g,g in enumerate(title["glyphs"]):
6259 title["glyphs"][i_g]["x"] = title["x"] + g["x"]
6260 title["glyphs"][i_g]["y"] = title["y"]
6261 ticks = []
6262 for i_t,text in enumerate(reversed(legend_values)):
6263 x1 = box["x"] + box["w"]
6264 x2 = x1 + heatmap_pad
6265 y = legend_y + (i_t * (legend_h / (len(legend_values) - 1)))
6266 label = copy.deepcopy(list(reversed(tick_labels))[i_t])
6267 label["x"] = x2 + (heatmap_pad / 2)
6268 label["y"] = y
6270 # Adjust the label y position to center on first numeric char
6271 center_glyph_h = label["glyphs"][0]["h"]
6272 label["y"] = y + (center_glyph_h / 2)
6273 for i_g,g in enumerate(label["glyphs"]):
6274 label["glyphs"][i_g]["x"] = label["x"] + g["x"]
6275 label["glyphs"][i_g]["y"] = label["y"]
6277 tick = {"text": text, "x1": x1, "x2":x2, "y": y, "label": copy.deepcopy(label) }
6278 ticks.append(tick)
6280 legend[variant_type] = {"title": title, "box": box, "ticks": ticks}
6281 legend_w += entry_w
6283 # -----------------------------------------------------------------------------
6284 # Draw Tree
6286 logging.info(f"Calculating final image dimensions.")
6287 # Final image dimensions
6288 if len(legend) > 0:
6289 logging.info("Final dimensions with legend.")
6290 width = math.ceil(legend_x + legend_w + margin)
6291 height = math.ceil(margin + variant_labels_wmax + tip_pad + tree_h + heatmap_s + margin)
6292 elif heatmap_w > 0 and heatmap_h > 0:
6293 logging.info("Final dimensions with heatmap.")
6294 width = math.ceil(heatmap_x + heatmap_w + margin)
6295 height = math.ceil(margin + variant_labels_wmax + tip_pad + tree_h + heatmap_s + margin)
6296 else:
6297 logging.info("Final dimensions without heatmap.")
6298 width = math.ceil(node_labels_x + node_labels_wmax + margin)
6299 height = math.ceil(tree_h + (2*margin))
6301 svg_path = os.path.join(output_dir, f"{prefix}plot.svg")
6302 logging.info(f"Rendering output svg ({width}x{height}): {svg_path}")
6303 with open(svg_path, 'w') as outfile:
6304 header = textwrap.dedent(
6305 f"""\
6306 <svg
6307 version="1.1"
6308 xmlns="http://www.w3.org/2000/svg"
6309 xmlns:xlink="http://www.w3.org/1999/xlink"
6310 preserveAspectRatio="xMidYMid meet"
6311 width="{width}"
6312 height="{height}"
6313 viewbox="0 0 {width} {height}">
6314 """)
6316 outfile.write(header.strip() + "\n")
6318 branches = []
6319 tip_circles = []
6320 tip_dashes = []
6321 tip_labels = []
6322 variant_labels = OrderedDict()
6323 heatmap_boxes = OrderedDict()
6324 focal_boxes = OrderedDict()
6325 legend_entries = OrderedDict()
6327 # Variant Heatmap Labels
6328 for variant in variants:
6329 label = variants[variant]["label"]
6330 rx, ry = label["x"], label["y"]
6331 if variant not in variant_labels:
6332 variant_labels[variant] = []
6333 heatmap_boxes[variant] = []
6334 for g in label["glyphs"]:
6335 line = f"<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}' />"
6336 variant_labels[variant].append(line)
6338 for node,data in plot_data.items():
6339 logging.debug(f"node: {node}, data: {data}")
6341 # Root branch, add custom little starter branch
6342 if not data["parent"]:
6343 cx, cy = data['x'], data['y']
6344 px, py = data['x'] - root_branch, cy
6345 # Non-root
6346 else:
6347 p_data = plot_data[data["parent"]]
6348 cx, cy, px, py = data['x'], data['y'], p_data['x'], p_data['y']
6350 # Draw tree lines
6351 if node in tree_nodes:
6352 line = f"<path d='M {cx} {cy} H {px} V {py}' style='stroke:black;stroke-width:2;fill:none'/>"
6353 branches.append(line)
6355 # Focal box
6356 if node in focal:
6357 # Draw the box to the midway point between it and the next label
6358 if len(plot_data) == 1:
6359 rh = node_labels_hmax
6360 else:
6361 # Figure out adjacent tip
6362 tip_i = all_samples.index(node)
6363 adjacent_tip = all_samples[tip_i+1] if tip_i == 0 else all_samples[tip_i-1]
6364 rh = abs(plot_data[adjacent_tip]["y"] - data['y'])
6365 ry = data['y'] - (rh /2)
6366 rx = node_labels_x
6367 rw = node_labels_wmax + tip_pad + heatmap_w
6369 rect = f"<rect width='{rw}' height='{rh}' x='{rx}' y='{ry}' rx='1' ry='1' style='fill:grey;fill-opacity:.40'/>"
6370 focal_boxes[node] = [rect]
6372 # Dashed line: tree -> tip
6373 if node in tree_tips:
6374 lx = node_labels_x + (node_labels_wmax - data["label"]["w"])
6375 line = f"<path d='M {cx + 4} {cy} H {lx - 4}' style='stroke:grey;stroke-width:1;fill:none' stroke-dasharray='4 4'/>"
6376 tip_dashes.append(line)
6378 # Tip circles
6379 if node in focal:
6380 circle = f"<circle cx='{cx}' cy='{cy}' r='4' style='fill:black;stroke:black;stroke-width:1' />"
6381 tip_circles.append(circle)
6383 # Sample Labels
6384 if node in all_samples:
6385 for g in data["label"]["glyphs"]:
6386 line = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />"
6387 tip_labels.append(line)
6388 # Heatmap
6389 if node in all_samples:
6390 for variant,d in data["heatmap"].items():
6392 hover_text = "\n".join([f"{k}: {v}" for k,v in d["hovertext"].items() if k not in all_samples and not k.endswith("_alt")])
6393 rect = f"<rect width='{d['w']}' height='{d['h']}' x='{d['x']}' y='{d['y']}' rx='{d['r']}' ry='{d['r']}' style='fill:{d['fill']};fill-opacity:{d['fill_opacity']};stroke-width:1;stroke:{d['stroke']}'>"
6394 title = f"<title>{hover_text}</title>"
6395 heatmap_boxes[variant] += [rect, title, "</rect>"]
6397 # Legend
6398 for variant_type in legend:
6399 legend_entries[variant_type] = []
6400 # Gradient
6401 fill = palette[variant_type]
6402 gradient = f"""
6403 <linearGradient id="{variant_type}_gradient" x1="0" x2="0" y1="0" y2="1" gradientTransform="rotate(180 0.5 0.5)">
6404 <stop offset="0%" stop-color="{fill}" stop-opacity="{score_min_alpha}" />
6405 <stop offset="100%" stop-color="{fill}" stop-opacity="{score_max_alpha}" />
6406 </linearGradient>"""
6407 legend_entries[variant_type].append(gradient)
6408 # Legend Title
6409 title = legend[variant_type]["title"]
6410 for g in title["glyphs"]:
6411 label = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />"
6412 legend_entries[variant_type].append(label)
6413 # Legend box
6414 d = legend[variant_type]["box"]
6415 rect = f"<rect width='{d['w']}' height='{d['h']}' x='{d['x']}' y='{d['y']}' style='fill:{d['fill']};stroke-width:1;stroke:black'/>"
6416 legend_entries[variant_type].append(rect)
6417 # Legend Ticks
6418 for t in legend[variant_type]["ticks"]:
6419 # Tick line
6420 line = f"<path d='M {t['x1']} {t['y']} H {t['x2']} V {t['y']}' style='stroke:black;stroke-width:1;fill:none'/>"
6421 legend_entries[variant_type].append(line)
6422 # Tick Label
6423 for g in t["label"]["glyphs"]:
6424 label = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />"
6425 legend_entries[variant_type].append(label)
6427 # Draw elements in groups
6428 indent = " "
6429 outfile.write(f"{indent * 1}<g id='Plot'>" + "\n")
6431 # White canvas background
6432 background = f"<rect width='{width}' height='{height}' x='0' y='0' style='fill:white;stroke-width:1;stroke:white'/>"
6433 outfile.write(f"{indent * 2}<g id='Background'>\n{indent * 3}{background}\n{indent * 2}</g>" + "\n")
6435 # Focal boxes start
6436 outfile.write(f"{indent * 2}<g id='Focal'>" + "\n")
6437 # Focal Boxes
6438 outfile.write(f"{indent * 3}<g id='Focal Boxes'>\n")
6439 for sample,boxes in focal_boxes.items():
6440 outfile.write(f"{indent * 4}<g id='{sample}'>\n")
6441 outfile.write("\n".join([f"{indent * 5}{b}" for b in boxes]) + "\n")
6442 outfile.write(f"{indent * 4}</g>\n")
6443 outfile.write(f"{indent * 3}</g>\n")
6444 # Focal boxes end
6445 outfile.write(f"{indent * 2}</g>\n")
6447 # Tree Start
6448 outfile.write(f"{indent * 2}<g id='Tree'>" + "\n")
6449 # Branches
6450 outfile.write(f"{indent * 3}<g id='Branches'>\n")
6451 outfile.write("\n".join([f"{indent * 4}{b}" for b in branches]) + "\n")
6452 outfile.write(f"{indent * 3}</g>\n")
6453 # Tip Circles
6454 outfile.write(f"{indent * 3}<g id='Tip Circles'>\n")
6455 outfile.write("\n".join([f"{indent * 4}{c}" for c in tip_circles]) + "\n")
6456 outfile.write(f"{indent * 3}</g>\n")
6457 # Tip Dashes
6458 outfile.write(f"{indent * 3}<g id='Tip Dashes'>\n")
6459 outfile.write("\n".join([f"{indent * 4}{d}" for d in tip_dashes]) + "\n")
6460 outfile.write(f"{indent * 3}</g>\n")
6461 # Tip Labels
6462 outfile.write(f"{indent * 3}<g id='Tip Labels'>\n")
6463 outfile.write("\n".join([f"{indent * 4}{t}" for t in tip_labels]) + "\n")
6464 outfile.write(f"{indent * 3}</g>\n")
6465 # Tree End
6466 outfile.write(f"{indent}</g>\n")
6468 # Heatmap Start
6469 outfile.write(f"{indent}<g id='Heatmap'>" + "\n")
6470 # Variant Labels
6471 outfile.write(f"{indent * 3}<g id='Variant Labels'>\n")
6472 for variant,paths in variant_labels.items():
6473 outfile.write(f"{indent * 4}<g id='{variant}'>\n")
6474 outfile.write("\n".join([f"{indent * 5}{p}" for p in paths]) + "\n")
6475 outfile.write(f"{indent * 4}</g>\n")
6476 outfile.write(f"{indent * 3}</g>\n")
6477 # Heatmap Boxes
6478 outfile.write(f"{indent * 3}<g id='Heatmap Boxes'>\n")
6479 for variant,boxes in heatmap_boxes.items():
6480 outfile.write(f"{indent * 4}<g id='{variant}'>\n")
6481 outfile.write("\n".join([f"{indent * 5}{b}" for b in boxes]) + "\n")
6482 outfile.write(f"{indent * 4}</g>\n")
6483 outfile.write(f"{indent * 3}</g>\n")
6484 # Heatmap End
6485 outfile.write(f"{indent * 2}</g>\n")
6487 # Legend
6488 if score_min != None and score_max != None:
6489 outfile.write(f"{indent * 2}<g id='Legend'>\n")
6490 for variant_type in legend_entries:
6491 outfile.write(f"{indent * 3}<g id='{variant_type}'>\n")
6492 for element in legend_entries[variant_type]:
6493 outfile.write(f"{indent * 4}{element}\n")
6494 outfile.write(f"{indent * 3}</g>\n")
6495 outfile.write(f"{indent * 4}</g>\n")
6497 # outfile.write(f"{indent * 3}<g id='{variant_type}'>")
6498 # fill = palette[variant_type]
6499 # gradient = f"""
6500 # <linearGradient id="{variant_type}_gradient" x1="0" x2="0" y1="0" y2="1" gradientTransform="rotate(180 0.5 0.5)">
6501 # <stop offset="0%" stop-color="{fill}" stop-opacity="{score_min_alpha}" />
6502 # <stop offset="100%" stop-color="{fill}" stop-opacity="{score_max_alpha}" />
6503 # </linearGradient>"""
6504 # outfile.write(f"{indent * 4}{gradient}\n")
6505 # outfile.write(f"{indent * 3}</g>\n")
6506 # # Box
6507 # rect = f"<rect width='{legend_w}' height='{legend_h}' x='{legend_x}' y='{legend_y}' style='stroke:black;stroke-width:1;fill:url(#LegendGradient)'/>"
6508 # outfile.write(f"{indent * 3}{rect}\n")
6509 # # Ticks
6510 # for variant_type in legend_ticks:
6511 # for t in legend_ticks[variant_type]:
6512 # line = f"<path d='M {t['x1']} {t['y']} H {t['x2']} V {t['y']}' style='stroke:black;stroke-width:1;fill:none'/>"
6513 # outfile.write(f"{indent * 3}{line}\n")
6514 # # Label
6515 # for g in t["label"]["glyphs"]:
6516 # glyph = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />"
6517 # outfile.write(f"{indent * 3}{glyph}\n")
6519 # Close out the plot group
6520 outfile.write(f"{indent * 1}</g>\n")
6521 outfile.write("</svg>")
6523 png_path = os.path.join(output_dir, f"{prefix}plot.png")
6524 logging.info(f"Rendering output png ({width}x{height}): {png_path}")
6525 cairosvg.svg2png(url=svg_path, write_to=png_path, output_width=width, output_height=height, scale=png_scale)
6527 return svg_path
6529def cli(args:str=None):
6531 # Parse input args from system or function input
6532 sys_argv_original = sys.argv
6533 if args != None:
6534 sys.argv = f"pangwas {args}".split(" ")
6535 command = " ".join(sys.argv)
6536 if len(sys.argv) == 1:
6537 sys.argv = ["pangwas", "--help"]
6539 # Handle sys exits by argparse more gracefully
6540 try:
6541 options = get_options()
6542 except SystemExit as exit_status:
6543 if f"{exit_status}" == "0":
6544 return 0
6545 elif f"{exit_status}" == "2":
6546 return exit_status
6547 else:
6548 msg = f"pangwas cli exited with status: {exit_status}"
6549 logging.error(msg)
6550 return exit_status
6552 # Restore original argv
6553 sys.argv = sys_argv_original
6555 if options.version:
6556 import importlib.metadata
6557 version = importlib.metadata.version("pangwas")
6558 print(f"pangwas v{version}")
6559 return 0
6561 logging.info("Begin")
6562 logging.info(f"Command: {command}")
6564 # Organize options as kwargs for functions
6565 kwargs = {k:v for k,v in vars(options).items() if k not in ["version", "subcommand"]}
6566 try:
6567 fn = globals()[options.subcommand]
6568 fn(**kwargs)
6569 except KeyError:
6570 logging.error(f"A pangwas subcommand is required (ex. pangwas extract).")
6571 return 1
6573 logging.info("Done")
6575if __name__ == "__main__":
6576 cli()