Coverage for src/pangwas/pangwas.py: 97%

3607 statements  

coverage.py v7.7.1, created at 2025-03-25 21:02 +0000

1#!/usr/bin/env python3 

2 

3# Base python packages 

4from collections import OrderedDict, Counter 

5import copy 

6import itertools 

7import logging 

8import math 

9import os 

10import re 

11import shutil 

12import subprocess 

13import sys 

14import textwrap 

15 

16from tqdm import tqdm 

17 

18NUCLEOTIDES = ["A", "C", "G", "T"] 

19 

20# Default arguments for programs 

21CLUSTER_ARGS = "-k 13 --min-seq-id 0.90 -c 0.90 --cluster-mode 2 --max-seqs 300" 

22DEFRAG_ARGS = "-k 13 --min-seq-id 0.90 -c 0.90 --cov-mode 1" 

23ALIGN_ARGS = "--adjustdirection --localpair --maxiterate 1000 --addfragments" 

24TREE_ARGS = "-safe -m MFP" 

25GWAS_ARGS = "--lmm" 

26 

27LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper() 

28logging.basicConfig(level=LOGLEVEL, stream=sys.stdout, format='%(asctime)s %(funcName)20s %(levelname)8s: %(message)s') 
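# The log level can be overridden through the environment, for example (in a shell):
#   LOGLEVEL=DEBUG pangwas extract --gff sample1.gff3
# When LOGLEVEL is unset, it defaults to INFO.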

29 

30# Notify users when they use an algorithm from ppanggolin 

31PPANGGOLIN_NOTICE = """ 

32 The defrag algorithm is adapted from PPanGGOLiN (v2.2.0) which is  

33 distributed under the CeCILL FREE SOFTWARE LICENSE AGREEMENT LABGeM. 

34 - Please cite: Gautreau G et al. (2020) PPanGGOLiN: Depicting microbial diversity via a partitioned pangenome graph. 

35 PLOS Computational Biology 16(3): e1007732. https://doi.org/10.1371/journal.pcbi.1007732 

36 - PPanGGOLiN license: https://github.com/labgem/PPanGGOLiN/blob/2.2.0/LICENSE.txt 

37 - Defrag algorithm source: https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317 

38""" 

39 

40pangwas_description = "pangenome wide association study (panGWAS) pipeline." 

41annotate_description = "Annotate genomic assemblies with bakta." 

42extract_description = "Extract sequences and annotations from GFF files." 

43collect_description = "Collect extracted sequences from multiple samples into one file." 

44cluster_description = "Cluster nucleotide sequences with mmseqs." 

45defrag_description = "Defrag clusters by associating fragments with their parent cluster." 

46summarize_description = "Summarize clusters according to their annotations." 

47align_description = "Align clusters using mafft and create a pangenome alignment." 

48snps_description = "Extract SNPs from a pangenome alignment." 

49structural_description = "Extract structural variants from cluster alignments." 

50presence_absence_description = "Extract presence absence of clusters." 

51combine_description = "Combine variants from multiple Rtab files." 

52table_to_rtab_description = "Convert a TSV/CSV table to an Rtab file based on regex filters." 

53tree_description = "Estimate a maximum-likelihood tree with IQ-TREE." 

54root_tree_description = "Root tree on outgroup taxa." 

55binarize_description = "Convert a categorical column to multiple binary (0/1) columns." 

56vcf_to_rtab_description = "Convert a VCF file to an Rtab file." 

57gwas_description = "Run genome-wide association study (GWAS) tests with pyseer." 

58heatmap_description = "Plot a heatmap of variants alongside a tree." 

59manhattan_description = "Plot the distribution of variant p-values across the genome." 

60 

61def get_options(args:str=None): 

62 import argparse 

63 

64 sys_argv_original = sys.argv 

65 

66 if args != None: 

67 sys.argv = args.split(" ") 

68 

69 description = textwrap.dedent( 

70 f"""\ 

71 {pangwas_description} 

72 

73 ANNOTATE 

74 annotate: {annotate_description} 

75 

76 CLUSTER 

77 extract: {extract_description} 

78 collect: {collect_description} 

79 cluster: {cluster_description} 

80 defrag: {defrag_description} 

81 summarize: {summarize_description} 

82 

83 ALIGN 

84 align: {align_description} 

85 

86 VARIANTS 

87 structural: {structural_description}  

88 snps: {snps_description} 

89 presence_absence: {presence_absence_description} 

90 

91 TREE 

92 tree: {tree_description} 

93 

94 GWAS 

95 gwas: {gwas_description} 

96 

97 PLOT 

98 manhattan: {manhattan_description} 

99 heatmap: {heatmap_description} 

100 

101 UTILITY 

102 root_tree: {root_tree_description} 

103 binarize: {binarize_description} 

104 table_to_rtab: {table_to_rtab_description} 

105 vcf_to_rtab: {vcf_to_rtab_description} 

106 """) 

107 parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter) 

108 parser.add_argument('--version', help='Display program version.', action="store_true") 

109 subbcommands = parser.add_subparsers(dest="subcommand") 

110 

111 # ------------------------------------------------------------------------- 

112 # Annotate 

113 

114 description = textwrap.dedent( 

115 f"""\ 

116 {annotate_description} 

117 

118 Takes as input a FASTA file of genomic assemblies. Outputs a GFF file 

119 of annotations, among many other formats from bakta. 

120 

121 All additional arguments will be passed to the `bakta` CLI. 

122 

123 Examples: 

124 > pangwas annotate --fasta sample1.fasta --db database/bakta 

125 > pangwas annotate --fasta sample1.fasta --db database/bakta --sample sample1 --threads 2 --genus Streptococcus 

126 """) 

127 

128 annotate_parser = subbcommands.add_parser('annotate', description = description, formatter_class=argparse.RawTextHelpFormatter) 

129 

130 annotate_req_parser = annotate_parser.add_argument_group("required arguments") 

131 annotate_req_parser.add_argument("--fasta", required=True, help='Input FASTA sequences.') 

132 annotate_req_parser.add_argument("--db", required=True, help='bakta database directory.') 

133 

134 annotate_out_opt_parser = annotate_parser.add_argument_group("optional output arguments") 

135 annotate_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".") 

136 annotate_out_opt_parser.add_argument('--prefix', help='Output file prefix. If not provided, will be parsed from the fasta file name.') 

137 annotate_out_opt_parser.add_argument("--tmp", help='Temporary directory.') 

138 

139 annotate_opt_parser = annotate_parser.add_argument_group("optional arguments") 

140 annotate_opt_parser.add_argument("--sample", help='Sample identifier. If not provided, will be parsed from the fasta file name.') 

141 annotate_opt_parser.add_argument("--threads", help='CPU threads for bakta. (default: 1)', default=1) 

142 

143 # ------------------------------------------------------------------------- 

144 # Extract Sequences 

145 

146 description = textwrap.dedent( 

147 f"""\ 

148 {extract_description} 

149 

150 Takes as input a GFF annotations file. If sequences are not included, a FASTA 

151 of genomic contigs must also be provided. Both annotated and unannotated regions  

152 will be extracted. Outputs a TSV table of extracted sequence regions. 

153 

154 Examples: 

155 > pangwas extract --gff sample1.gff3 

156 > pangwas extract --gff sample1.gff3 --fasta sample1.fasta --min-len 10 

157 """) 

158 

159 extract_parser = subbcommands.add_parser('extract', description = description, formatter_class=argparse.RawTextHelpFormatter) 

160 

161 extract_req_parser = extract_parser.add_argument_group("required arguments") 

162 extract_req_parser.add_argument('--gff', required=True, help='Input GFF annotations.') 

163 

164 extract_out_opt_parser = extract_parser.add_argument_group("optional output arguments") 

165 extract_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".") 

166 extract_out_opt_parser.add_argument('--prefix', help='Output file prefix. If not provided, will be parsed from the gff file name.') 

167 

168 extract_opt_parser = extract_parser.add_argument_group("optional arguments") 

169 extract_opt_parser.add_argument('--fasta', help='Input FASTA sequences, if not provided at the end of the GFF.') 

170 extract_opt_parser.add_argument('--max-len', help='Maximum length of sequences to extract (default: 100000).', type=int, default=100000) 

171 extract_opt_parser.add_argument('--min-len', help='Minimum length of sequences to extract (default: 20).', type=int, default=20) 

172 extract_opt_parser.add_argument('--sample', help='Sample identifier to use. If not provided, is parsed from the gff file name.') 

173 extract_opt_parser.add_argument('--regex', help='Only extract gff lines that match this regular expression.') 

174 

175 # ------------------------------------------------------------------------- 

176 # Collect Sequences 

177 

178 description = textwrap.dedent( 

179 f"""\ 

180 {collect_description} 

181 

182 Takes as input multiple TSV files from extract, which can be supplied  

183 as either space-separated paths or a text file containing paths.  

184 Duplicate sequence IDs will be identified and given the suffix '.#'. 

185 Outputs concatenated FASTA and TSV files. 

186 

187 Examples: 

188 > pangwas collect --tsv sample1.tsv sample2.tsv sample3.tsv sample4.tsv 

189 > pangwas collect --tsv-paths tsv_paths.txt 

190 """) 

191 

192 collect_parser = subbcommands.add_parser('collect', description = description, formatter_class=argparse.RawTextHelpFormatter) 

193 

194 collect_req_parser = collect_parser.add_argument_group("required arguments (mutually-exclusive)") 

195 collect_input = collect_req_parser.add_mutually_exclusive_group(required=True) 

196 collect_input.add_argument('--tsv', help='TSV files from the extract subcommand.', nargs='+') 

197 collect_input.add_argument('--tsv-paths', help='TXT file containing paths to TSV files.') 

198 

199 collect_out_opt_parser = collect_parser.add_argument_group("optional output arguments") 

200 collect_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".") 

201 collect_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

202 

203 # ------------------------------------------------------------------------- 

204 # Cluster Sequences (mmseqs) 

205 

206 description = textwrap.dedent( 

207 f"""\ 

208 {cluster_description} 

209 

210 Takes as input a FASTA file of sequences for clustering. Calls MMSeqs2  

211 to cluster sequences and identify a representative sequence. Outputs a 

212 TSV table of sequence clusters and a FASTA of representative sequences. 

213 

214 Any additional arguments will be passed to `mmseqs cluster`. If no additional 

215 arguments are used, the following default args will apply: 

216 {CLUSTER_ARGS} 

217 

218 Examples: 

219 > pangwas cluster --fasta collect.fasta 

220 > pangwas cluster --fasta collect.fasta --threads 2 -k 13 --min-seq-id 0.90 -c 0.90 

221 """) 

222 

223 cluster_parser = subbcommands.add_parser('cluster', description = description, formatter_class=argparse.RawTextHelpFormatter) 

224 

225 cluster_req_parser = cluster_parser.add_argument_group("required arguments") 

226 cluster_req_parser.add_argument('-f', '--fasta', required=True, help='FASTA file of input sequences to cluster.') 

227 

228 cluster_out_opt_parser = cluster_parser.add_argument_group("optional output arguments") 

229 cluster_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".") 

230 cluster_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

231 cluster_out_opt_parser.add_argument('--tmp', help='Tmp directory (default: tmp).', default="tmp") 

232 

233 cluster_opt_parser = cluster_parser.add_argument_group("optional arguments") 

234 cluster_opt_parser.add_argument('--memory', help='Memory limit for mmseqs split. (default: 1G)', default="1G") 

235 cluster_opt_parser.add_argument('--no-clean', help="Don't clean up intermediate files.", action="store_false", dest="clean") 

236 cluster_opt_parser.add_argument('--threads', help='CPU threads for mmseqs. (default: 1)', default=1) 

237 

238 # ------------------------------------------------------------------------- 

239 # Defrag clusters 

240 

241 description = textwrap.dedent( 

242 f"""\ 

243 {defrag_description} 

244 

245 Takes as input the TSV clusters and FASTA representatives from cluster. 

246 Outputs a new cluster table and representative sequences fasta. 

247 

248 {PPANGGOLIN_NOTICE} 

249 

250 Any additional arguments will be passed to `mmseqs search`. If no additional 

251 arguments are used, the following default args will apply: 

252 {DEFRAG_ARGS} 

253 

254 Examples: 

255 > pangwas defrag --clusters clusters.tsv --representative representative.fasta --prefix defrag 

256 > pangwas defrag --clusters clusters.tsv --representative representative.fasta --prefix defrag --threads 2 -k 13 --min-seq-id 0.90 -c 0.90 

257 """) 

258 

259 defrag_parser = subbcommands.add_parser('defrag', description=description, formatter_class=argparse.RawTextHelpFormatter) 

260 defrag_req_parser = defrag_parser.add_argument_group("required arguments") 

261 defrag_req_parser.add_argument('--clusters', required=True, help='TSV file of clusters from mmseqs.') 

262 defrag_req_parser.add_argument('--representative', required=True, help='FASTA file of cluster representative sequences.') 

263 

264 defrag_out_opt_parser = defrag_parser.add_argument_group("optional output arguments") 

265 defrag_out_opt_parser.add_argument("--outdir", help='Output directory. (default: .)', default=".") 

266 defrag_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

267 defrag_out_opt_parser.add_argument('--tmp', help='Tmp directory (default: tmp).', default="tmp") 

268 

269 defrag_opt_parser = defrag_parser.add_argument_group("optional arguments") 

270 defrag_opt_parser.add_argument('--memory', help='Memory limit for mmseqs split. (default: 2G)', default="2G") 

271 defrag_opt_parser.add_argument('--no-clean', help="Don't clean up intermediate files.", action="store_false", dest="clean") 

272 defrag_opt_parser.add_argument('--threads', help='CPU threads for mmseqs. (default: 1)', default=1) 

273 

274 # ------------------------------------------------------------------------- 

275 # Summarize clusters 

276 

277 description = textwrap.dedent( 

278 f"""\ 

279 {summarize_description} 

280 

281 Takes as input the TSV table from collect, and the clusters table from  

282 either cluster or defrag. Outputs a TSV table of summarized clusters  

283 with their annotations. 

284 

285 Examples: 

286 > pangwas summarize --clusters defrag.clusters.tsv --regions regions.tsv 

287 """) 

288 

289 summarize_parser = subbcommands.add_parser('summarize', description=description, formatter_class=argparse.RawTextHelpFormatter) 

290 summarize_req_parser = summarize_parser.add_argument_group("required arguments") 

291 summarize_req_parser.add_argument('--clusters', required=True, help='TSV file of clusters from cluster or defrag.') 

292 summarize_req_parser.add_argument('--regions', required=True, help='TSV file of sequence regions from collect.') 

293 

294 summarize_opt_parser = summarize_parser.add_argument_group("optional arguments") 

295 summarize_opt_parser.add_argument("--max-product-len", help='Truncate the product description to this length if used as an identifier. (default: 50)', type=int, default=50) 

296 summarize_opt_parser.add_argument("--min-samples", help='Cluster must be observed in at least this many samples to be summarized. (default: 1)', type=int, default=1) 

297 summarize_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

298 summarize_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

299 summarize_opt_parser.add_argument('--threshold', help='Require this proportion of samples to have annotations in agreement. (default: 0.5)', type=float, default=0.5) 

300 

301 # ------------------------------------------------------------------------- 

302 # Align clusters 

303 

304 description = textwrap.dedent( 

305 f"""\ 

306 {align_description} 

307 

308 Takes as input the clusters from summarize and the sequence regions 

309 from collect. Outputs multiple sequence alignments per cluster  

310 as well as a pangenome alignment of concatenated clusters.\n 

311 

312 Any additional arguments will be passed to `mafft`. If no additional 

313 arguments are used, the following default args will apply: 

314 {ALIGN_ARGS} 

315 

316 Examples: 

317 > pangwas align --clusters clusters.tsv --regions regions.tsv 

318 > pangwas align --clusters clusters.tsv --regions regions.tsv --threads 2 --exclude-singletons --localpair --maxiterate 100 

319 """) 

320 

321 align_parser = subbcommands.add_parser('align', description=description, formatter_class=argparse.RawTextHelpFormatter) 

322 

323 align_req_parser = align_parser.add_argument_group("required arguments") 

324 align_req_parser.add_argument('--clusters', help='TSV of clusters from summarize.', required=True) 

325 align_req_parser.add_argument('--regions', help='TSV of sequence regions from collect.', required=True) 

326 

327 align_out_opt_parser = align_parser.add_argument_group("optional output arguments") 

328 align_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

329 align_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

330 

331 align_opt_parser = align_parser.add_argument_group("optional arguments") 

332 align_opt_parser.add_argument('--threads', help='CPU threads for running mafft in parallel. (default: 1)', type=int, default=1) 

333 align_opt_parser.add_argument('--exclude-singletons', help='Exclude clusters found in only 1 sample.', action="store_true") 

334 

335 # ------------------------------------------------------------------------- 

336 # Variants: Structural 

337 

338 description = textwrap.dedent( 

339 f"""\ 

340 {structural_description} 

341 

342 Takes as input the summarized clusters and their individual alignments. 

343 Outputs an Rtab file of structural variants. 

344 

345 Examples: 

346 > pangwas structural --clusters clusters.tsv --alignments alignments 

347 > pangwas structural --clusters clusters.tsv --alignments alignments --min-len 100 --min-indel-len 10 

348 """) 

349 

350 structural_parser = subbcommands.add_parser('structural', description=description, formatter_class=argparse.RawTextHelpFormatter) 

351 

352 structural_req_parser = structural_parser.add_argument_group("required arguments") 

353 structural_req_parser.add_argument('--clusters', required=True, help='TSV of clusters from summarize.') 

354 structural_req_parser.add_argument('--alignments', required=True, help='Directory of cluster alignments (not consensus alignments!).') 

355 

356 structural_out_opt_parser = structural_parser.add_argument_group("optional output arguments") 

357 structural_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

358 structural_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

359 

360 structural_opt_parser = structural_parser.add_argument_group("optional arguments") 

361 structural_opt_parser.add_argument('--min-len', help='Minimum variant length. (default: 10)', type=int, default=10) 

362 structural_opt_parser.add_argument('--min-indel-len', help='Minimum indel length. (default: 3)', type=int, default=3) 

363 

364 

365 # ------------------------------------------------------------------------- 

366 # Variants: SNPs 

367 

368 description = textwrap.dedent( 

369 f"""\ 

370 {snps_description} 

371 

372 Takes as input the pangenome alignment fasta, bed, and consensus file from align. 

373 Outputs an Rtab file of SNPs. 

374 

375 Examples: 

376 > pangwas snps --alignment pangenome.aln --bed pangenome.bed --consensus pangenome.consensus.fasta 

377 > pangwas snps --alignment pangenome.aln --bed pangenome.bed --consensus pangenome.consensus.fasta --structural structural.Rtab --core 0.90 --indel-window 3 --snp-window 10 

378 """) 

379 

380 snps_parser = subbcommands.add_parser('snps', description=description, formatter_class=argparse.RawTextHelpFormatter) 

381 

382 snps_req_parser = snps_parser.add_argument_group("required arguments") 

383 snps_req_parser.add_argument('--alignment', required=True, help='Fasta sequences alignment.') 

384 snps_req_parser.add_argument('--bed', required=True, help='Bed file of coordinates.') 

385 snps_req_parser.add_argument('--consensus', required=True, help='Fasta consensus/representative sequence.') 

386 

387 snps_rec_parser = snps_parser.add_argument_group("optional but recommended arguments") 

388 snps_rec_parser.add_argument('--structural', help='Rtab from the structural subcommand, used to avoid treating terminal ends as indels.') 

389 

390 snps_out_opt_parser = snps_parser.add_argument_group("optional output arguments") 

391 snps_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

392 snps_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

393 

394 snps_opt_parser = snps_parser.add_argument_group("optional arguments") 

395 snps_opt_parser.add_argument('--core', help='Core threshold for calling core SNPs. (default: 0.95)', type=float, default=0.95) 

396 snps_opt_parser.add_argument('--indel-window', help='Exclude SNPs that are within this proximity to indels. (default: 0)', type=int, default=0) 

397 snps_opt_parser.add_argument('--snp-window', help='Exclude SNPs that are within this proximity to another SNP. (default: 0)', type=int, default=0) 

398 

399 # ------------------------------------------------------------------------- 

400 # Variants: Presence Absence 

401 

402 description = textwrap.dedent( 

403 f"""\ 

404 {presence_absence_description} 

405 

406 Takes as input a TSV of summarized clusters from summarize. 

407 Outputs an Rtab file of cluster presence/absence. 

408 

409 Examples: 

410 > pangwas presence_absence --clusters clusters.tsv 

411 """) 

412 

413 pres_parser = subbcommands.add_parser('presence_absence', description=description, formatter_class=argparse.RawTextHelpFormatter) 

414 pres_req_parser = pres_parser.add_argument_group("required arguments") 

415 pres_req_parser.add_argument('--clusters', help='TSV of clusters from summarize.', required=True) 

416 

417 pres_out_opt_parser = pres_parser.add_argument_group("optional output arguments") 

418 pres_out_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

419 pres_out_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

420 

421 # ------------------------------------------------------------------------- 

422 # Variants: Combine 

423 

424 description = textwrap.dedent( 

425 f"""\ 

426 {combine_description} 

427 

428 Takes as input a list of file paths to Rtab files. Outputs an Rtab file with 

429 the variants concatenated, ensuring consistent ordering of the sample columns. 

430 

431 Examples: 

432 > pangwas combine --rtab snps.Rtab presence_absence.Rtab 

433 > pangwas combine --rtab snps.Rtab structural.Rtab presence_absence.Rtab 

434 """) 

435 combine_parser = subbcommands.add_parser('combine', description=description, formatter_class=argparse.RawTextHelpFormatter) 

436 

437 combine_req_parser = combine_parser.add_argument_group("required arguments") 

438 combine_req_parser.add_argument('--rtab', required=True, help="Rtab variants files.", nargs='+') 

439 

440 combine_opt_parser = combine_parser.add_argument_group("optional arguments") 

441 combine_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

442 combine_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

443 

444 # ------------------------------------------------------------------------- 

445 # Variants: Table to Rtab 

446 

447 description = textwrap.dedent( 

448 f"""\ 

449 {table_to_rtab_description} 

450 

451 Takes as input a TSV/CSV table to convert, and a TSV/CSV of regex filters. 

452 The filter table should have the header: column, regex, name, where 'column' 

453 is the name of the column to search, 'regex' is the regular expression pattern, and 

454 'name' is how the output variant should be named in the Rtab. 

455 

456 An example `filter.tsv` might look like this: 

457 

458 column regex name 

459 assembly .*sample2.* sample2 

460 lineage .*2.* lineage_2 

461 

462 Here, the goal is to filter the assembly and lineage columns for particular values. 

463 

464 Examples: 

465 > pangwas table_to_rtab --table samplesheet.csv --filter filter.tsv 

466 """ 

467 ) 

468 

469 table_to_rtab_parser = subbcommands.add_parser('table_to_rtab', description=description, formatter_class=argparse.RawTextHelpFormatter) 

470 table_to_rtab_req_parser = table_to_rtab_parser.add_argument_group("required arguments") 

471 table_to_rtab_req_parser.add_argument('--table', required=True, help='TSV or CSV table.') 

472 table_to_rtab_req_parser.add_argument('--filter', required=True, help='TSV or CSV filter table.') 

473 

474 table_to_rtab_opt_parser = table_to_rtab_parser.add_argument_group("optional arguments") 

475 table_to_rtab_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

476 table_to_rtab_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

477 

478 # ------------------------------------------------------------------------- 

479 # Variants: VCF to Rtab 

480 

481 description = textwrap.dedent( 

482 f"""\ 

483 {vcf_to_rtab_description} 

484 

485 Takes as input a VCF file to convert to a SNPs Rtab file. 

486 

487 Examples: 

488 > pangwas vcf_to_rtab --vcf snps.vcf 

489 """ 

490 ) 

491 

492 vcf_to_rtab_parser = subbcommands.add_parser('vcf_to_rtab', description=description, formatter_class=argparse.RawTextHelpFormatter) 

493 vcf_to_rtab_req_parser = vcf_to_rtab_parser.add_argument_group("required arguments") 

494 vcf_to_rtab_req_parser.add_argument('--vcf', required=True, help='VCF file.') 

495 

496 vcf_to_rtab_opt_parser = vcf_to_rtab_parser.add_argument_group("optional arguments") 

497 vcf_to_rtab_opt_parser.add_argument("--bed", help='BED file with region coordinates and names.') 

498 vcf_to_rtab_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

499 vcf_to_rtab_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

500 

501 # ------------------------------------------------------------------------- 

502 # Tree 

503 

504 description = textwrap.dedent( 

505 f"""\ 

506 {tree_description} 

507 

508 Takes as input a multiple sequence alignment in FASTA format. If a SNP 

509 alignment is provided, an optional text file of constant sites can be  

510 included for correction. Outputs a maximum-likelihood tree, as well as  

511 additional rooted trees if an outgroup is specified in the iqtree args. 

512 

513 Any additional arguments will be passed to `iqtree`. If no additional 

514 arguments are used, the following default args will apply: 

515 {TREE_ARGS} 

516 

517 Examples: 

518 > pangwas tree --alignment snps.core.fasta --constant-sites snps.constant_sites.txt 

519 > pangwas tree --alignment pangenome.aln --threads 4 -o sample1 --ufboot 1000 

520 """ 

521 ) 

522 tree_parser = subbcommands.add_parser('tree', description=description, formatter_class=argparse.RawTextHelpFormatter) 

523 

524 tree_req_parser = tree_parser.add_argument_group("required arguments") 

525 tree_req_parser.add_argument('--alignment', required=True, help='Multiple sequence alignment.') 

526 

527 tree_opt_parser = tree_parser.add_argument_group("optional arguments") 

528 tree_opt_parser.add_argument('--constant-sites', help='Text file containing constant sites correction for SNP alignment.') 

529 tree_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

530 tree_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

531 tree_opt_parser.add_argument('--threads', help='CPU threads for IQTREE. (default: 1)', default=1) 

532 

533 # ------------------------------------------------------------------------- 

534 # Root tree 

535 

536 description = textwrap.dedent( 

537 f"""\ 

538 {root_tree_description} 

539 

540 Takes as input a path to a phylogenetic tree and outgroup taxa. This is 

541 a utility script that is meant to fix IQ-TREE's creation of a multifurcated 

542 root node. It will position the root node along the midpoint of the branch 

543 between the outgroup taxa and all other samples. If no outgroup is selected, 

544 the tree will be rooted on the first taxon. Outputs a new tree in the specified 

545 tree format. 

546 

547 Note: This functionality is already included in the tree subcommand. 

548 

549 Examples: 

550 > pangwas root_tree --tree tree.treefile 

551 > pangwas root_tree --tree tree.treefile --outgroup sample1 

552 > pangwas root_tree --tree tree.treefile --outgroup sample1,sample4 

553 > pangwas root_tree --tree tree.nex --outgroup sample1 --tree-format nexus 

554 """ 

555 ) 

556 root_tree_parser = subbcommands.add_parser('root_tree', description=description, formatter_class=argparse.RawTextHelpFormatter) 

557 

558 root_tree_req_parser = root_tree_parser.add_argument_group("required arguments") 

559 root_tree_req_parser.add_argument('--tree', help='Path to phylogenetic tree.') 

560 

561 root_tree_opt_parser = root_tree_parser.add_argument_group("optional arguments") 

562 root_tree_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

563 root_tree_opt_parser.add_argument("--outgroup", help='Outgroup taxa as CSV string. If not specified, roots on first taxon.') 

564 root_tree_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

565 root_tree_opt_parser.add_argument('--tree-format', help='Tree format. (default: newick)', type=str, default="newick") 

566 

567 # ------------------------------------------------------------------------- 

568 # GWAS: Binarize 

569 

570 description = textwrap.dedent( 

571 f"""\ 

572 {binarize_description} 

573 

574 Takes as input a table (TSV or CSV) and the name of a column to binarize. 

575 Outputs a new table with separate columns for each categorical value. 

576 

577 Any additional arguments will be passed to `pyseer`. 

578 

579 Examples: 

580 > pangwas binarize --table samplesheet.csv --column lineage --output lineage.binarize.csv 

581 > pangwas binarize --table samplesheet.tsv --column outcome --output outcome.binarize.tsv 

582 """) 

583 

584 binarize_parser = subbcommands.add_parser('binarize', description=description, formatter_class=argparse.RawTextHelpFormatter) 

585 binarize_req_parser = binarize_parser.add_argument_group("required arguments") 

586 binarize_req_parser.add_argument('--table', required=True, help='TSV or CSV table.') 

587 binarize_req_parser.add_argument('--column', required=True, help='Column name to binarize.') 

588 

589 binarize_opt_parser = binarize_parser.add_argument_group("optional arguments") 

590 binarize_opt_parser.add_argument('--column-prefix', help='Prefix to add to column names.') 

591 binarize_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

592 binarize_opt_parser.add_argument("--output-delim", help='Output delimiter. (default: \t )', default="\t") 

593 binarize_opt_parser.add_argument('--prefix', help='Prefix for output files') 

594 binarize_opt_parser.add_argument('--transpose', help='Transpose output table.', action='store_true') 

595 

596 # ------------------------------------------------------------------------- 

597 # GWAS: pyseer 

598 

599 description = textwrap.dedent( 

600 f"""\ 

601 {gwas_description} 

602 

603 Takes as input the TSV file of summarized clusters from summarize, an Rtab file of variants, 

604 a TSV/CSV table of phenotypes, and a column name representing the trait of interest. 

605 Outputs tables of locus effects, and optionally lineage effects (bugwas) if specified. 

606 

607 Any additional arguments will be passed to `pyseer`. If no additional 

608 arguments are used, the following default args will apply: 

609 {GWAS_ARGS} 

610 

611 Examples: 

612 > pangwas gwas --variants combine.Rtab --table samplesheet.csv --column lineage --no-distances 

613 > pangwas gwas --variants combine.Rtab --table samplesheet.csv --column resistant --lmm --tree tree.rooted.nwk --clusters clusters.tsv --lineage-column lineage 

614 """) 

615 gwas_parser = subbcommands.add_parser('gwas', description=description, formatter_class=argparse.RawTextHelpFormatter) 

616 # Required options 

617 gwas_req_parser = gwas_parser.add_argument_group("required arguments") 

618 gwas_req_parser.add_argument('--variants', required=True, help='Rtab file of variants.') 

619 gwas_req_parser.add_argument('--table', required=True, help='TSV or CSV table of phenotypes (variables).') 

620 gwas_req_parser.add_argument('--column', required=True, help='Column name of trait (variable) in table.') 

621 # Lineage effects options 

622 gwas_lin_parser = gwas_parser.add_argument_group("optional arguments (bugwas)") 

623 gwas_lin_parser.add_argument('--lineage-column', help='Column name of lineages in table. Enables bugwas lineage effects.') 

624 # LMM options 

625 gwas_lmm_parser = gwas_parser.add_argument_group("optional arguments (lmm)") 

626 gwas_lmm_parser.add_argument('--no-midpoint', help='Disable midpoint rooting of the tree for the kinship matrix.', dest="midpoint", action="store_false") 

627 gwas_lmm_parser.add_argument('--tree', help='Newick phylogenetic tree. Not required if pyseer args includes --no-distances.') 

628 # Data Options 

629 gwas_data_parser = gwas_parser.add_argument_group("optional arguments (data)") 

630 gwas_data_parser.add_argument('--continuous', help='Treat the column trait as a continuous variable.', action="store_true") 

631 gwas_data_parser.add_argument('--exclude-missing', help='Exclude samples missing phenotype data.', action="store_true") 

632 

633 gwas_opt_parser = gwas_parser.add_argument_group("optional arguments") 

634 gwas_opt_parser.add_argument('--clusters', help='TSV of clusters from summarize.') 

635 gwas_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

636 gwas_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

637 gwas_opt_parser.add_argument('--threads', help='CPU threads for pyseer. (default: 1)', default=1) 

638 

639 # ------------------------------------------------------------------------- 

640 # Plot 

641 

642 description = textwrap.dedent( 

643 f"""\ 

644 {heatmap_description} 

645 

646 Takes as input a table of variants and/or a newick tree. The table can be either 

647 an Rtab file, or the locus effects TSV output from the gwas subcommand. 

648 If both a tree and a table are provided, the tree will determine the sample order  

649 and arrangement. If just a table is provided, sample order will follow the  

650 order of the sample columns. A TXT of focal sample IDs can also be supplied 

651 with one sample ID per line. Outputs a plot in SVG and PNG format. 

652 

653 Examples: 

654 > pangwas heatmap --tree tree.rooted.nwk 

655 > pangwas heatmap --rtab combine.Rtab 

656 > pangwas heatmap --gwas resistant.locus_effects.significant.tsv 

657 > pangwas heatmap --tree tree.rooted.nwk --rtab combine.Rtab --focal focal.txt 

658 > pangwas heatmap --tree tree.rooted.nwk --gwas resistant.locus_effects.significant.tsv 

659 > pangwas heatmap --tree tree.rooted.nwk --tree-width 500 --png-scale 2.0 

660 """ 

661 ) 

662 

663 heatmap_parser = subbcommands.add_parser('heatmap', description=description, formatter_class=argparse.RawTextHelpFormatter) 

664 

665 heatmap_req_parser = heatmap_parser.add_argument_group("optional variant arguments (mutually-exclusive)") 

666 heatmap_variants_input = heatmap_req_parser.add_mutually_exclusive_group(required=False) 

667 heatmap_variants_input.add_argument('--gwas', help='TSV table of variants from gwas subcommand.') 

668 heatmap_variants_input.add_argument('--rtab', help='Rtab table of variants.') 

669 

670 heatmap_tree_parser = heatmap_parser.add_argument_group("optional tree arguments") 

671 heatmap_tree_parser.add_argument('--tree', help='Tree file.') 

672 heatmap_tree_parser.add_argument('--tree-format', help='Tree format. (default: newick)', type=str, default="newick") 

673 heatmap_tree_parser.add_argument('--tree-width', help='Width of the tree in pixels. (default: 200)', type=int, default=200) 

674 heatmap_tree_parser.add_argument('--root-branch', help='Root branch length in pixels. (default: 10)', type=int, default=10) 

675 

676 heatmap_opt_parser = heatmap_parser.add_argument_group("optional arguments") 

677 heatmap_opt_parser.add_argument('--focal', help='TXT file of focal samples.') 

678 heatmap_opt_parser.add_argument('--font-size', help='Font size of the tree labels. (default: 16)', type=int, default=16) 

679 heatmap_opt_parser.add_argument('--font-family', help='Font family of the tree labels. (default: Roboto)', type=str, default="Roboto") 

680 heatmap_opt_parser.add_argument('--heatmap-scale', help='Scaling factor of heatmap boxes relative to the text. (default: 1.5)', type=float, default=1.5) 

681 heatmap_opt_parser.add_argument('--margin', help='Margin sizes in pixels. (default: 20)', type=int, default=20) 

682 heatmap_opt_parser.add_argument('--min-score', help='Filter GWAS variants by a minimum score (range: -1.0 to 1.0).', type=float) 

683 heatmap_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

684 heatmap_opt_parser.add_argument('--png-scale', help='Scaling factor of the PNG file. (default: 2.0)', type=float, default=2.0) 

685 heatmap_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

686 heatmap_opt_parser.add_argument('--tip-pad', help='Tip padding. (default: 10)', type=int, default=10) 

687 

688 # ------------------------------------------------------------------------- 

689 # Manhattan 

690 

691 description = textwrap.dedent( 

692 f"""\ 

693 {manhattan_description} 

694 

695 Takes as input a table of locus effects from the gwas subcommand and a BED 

696 file such as the one produced by the align subcommand. Outputs a  

697 manhattan plot in SVG and PNG format. 

698 

699 Examples: 

700 > pangwas manhattan --gwas locus_effects.tsv --bed pangenome.bed 

701 > pangwas manhattan --gwas locus_effects.tsv --bed pangenome.bed --syntenies chromosome --clusters pbpX --variant-types snp presence_absence 

702 """ 

703 ) 

704 

705 manhattan_parser = subbcommands.add_parser('manhattan', description=description, formatter_class=argparse.RawTextHelpFormatter) 

706 

707 manhattan_req_parser = manhattan_parser.add_argument_group("required arguments") 

708 manhattan_req_parser.add_argument('--bed', required=True, help='BED file with region coordinates and names.') 

709 manhattan_req_parser.add_argument('--gwas', required=True, help='TSV table of variants from gwas subcommand.') 

710 

711 manhattan_opt_parser = manhattan_parser.add_argument_group("optional filter arguments") 

712 manhattan_opt_parser.add_argument('--clusters', help='Only plot these clusters. (default: all)', nargs='+', default=["all"]) 

713 manhattan_opt_parser.add_argument('--syntenies', help='Only plot these synteny blocks. (default: all)', nargs='+', default=["all"]) 

714 manhattan_opt_parser.add_argument('--variant-types', help='Only plot these variant types. (default: all)', nargs='+', default=["all"]) 

715 

716 manhattan_opt_parser = manhattan_parser.add_argument_group("optional arguments") 

717 manhattan_opt_parser.add_argument('--font-size', help='Font size of the tree labels. (default: 16)', type=int, default=16) 

718 manhattan_opt_parser.add_argument('--font-family', help='Font family of the tree labels. (default: Roboto)', type=str, default="Roboto") 

719 manhattan_opt_parser.add_argument('--height', help='Plot height in pixels. (default: 500)', type=int, default=500) 

720 manhattan_opt_parser.add_argument('--margin', help='Margin sizes in pixels. (default: 20)', type=int, default=20) 

721 manhattan_opt_parser.add_argument('--max-blocks', help='Maximum number of blocks to draw before switching to pangenome coordinates. (default: 20)', type=int, default=20) 

722 manhattan_opt_parser.add_argument("--outdir", help='Output directory. (default: . )', default=".") 

723 manhattan_opt_parser.add_argument('--png-scale', help='Scaling factor of the PNG file. (default: 2.0)', type=float, default=2.0) 

724 manhattan_opt_parser.add_argument('--prefix', help='Prefix for output files.') 

725 manhattan_opt_parser.add_argument('--prop-x-axis', help='Make x-axis proportional to genomic length.', action="store_true") 

726 manhattan_opt_parser.add_argument('--width', help='Plot width in pixels. (default: 1000)', type=int, default=1000) 

727 manhattan_opt_parser.add_argument('--ymax', help='A -log10 value to use as the y axis max, to synchronize multiple plots.', type=float) 

728 

729 # ------------------------------------------------------------------------- 

730 # Finalize 

731 

732 defined, undefined = parser.parse_known_args() 

733 if len(undefined) > 0: 

734 defined.args = " ".join(undefined) 

735 logging.warning(f"Using undefined args: {defined.args}") 

736 

737 sys.argv = sys_argv_original 

738 return defined 

739 

740 

741def check_output_dir(output_dir): 
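    """
    Create the output directory if it does not already exist.

    Values of None, '.' and '' are treated as "no directory to create".
    """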

742 if output_dir != None and output_dir != "." and output_dir != "" and not os.path.exists(output_dir): 

743 logging.info(f"Creating output directory: {output_dir}") 

744 os.makedirs(output_dir) 

745 

746 

747def reverse_complement(sequence): 
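    """
    Return the reverse complement of a nucleotide sequence.

    Characters outside A/C/G/T are uppercased but otherwise left unchanged.
    A minimal sketch:

    >>> reverse_complement("AACG")   # expected: 'CGTT'
    """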

748 lookup = {"A": "T", "C": "G", "G": "C", "T": "A"} 

749 split = [] 

750 for nuc in sequence[::-1]: 

751 nuc = nuc.upper() 

752 compl = lookup[nuc] if nuc in lookup else nuc 

753 split.append(compl) 

754 return "".join(split) 

755 

756 

757def extract_cli_param(args: str, target: str): 

758 """ 

759 Extract a value flag from a string of CLI arguments. 

760 
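    Examples (illustrative, using the argument string from the parameters below):

    >>> extract_cli_param("iqtree -T 3 --seed 123", "--seed")   # expected: '123'
    >>> extract_cli_param("iqtree -T 3 --seed 123", "-T")       # expected: '3'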

761 :param args: String of CLI arguments (ex. 'iqtree -T 3 --seed 123'). 

762 :param target: String of the target parameter (ex. '-T', '--seed'). 

763 """ 

764 if not target.endswith(" "): 

765 target += " " 

766 val = None 

767 if target in args: 

768 start = args.index(target) + len(target) 

769 try: 

770 end = args[start:].index(' ') 

771 val = args[start:][:end] 

772 except ValueError: 

773 val = args[start:] 

774 return val 

775 

776 

777def run_cmd(cmd: str, output=None, err=None, quiet=False, display_cmd=True): 

778 """ 

779 Run a subprocess command. 
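
    A minimal usage sketch (assumes a POSIX `echo` is available on PATH):

    >>> stdout, stderr = run_cmd("echo hello", quiet=True)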

780 """ 

781 if display_cmd: 

782 logging.info(f"{cmd}") 

783 cmd_run = [str(c) for c in cmd.replace("\n", " ").split(" ") if c != ""] 

784 try: 

785 result = subprocess.run(cmd_run, check=True, capture_output=True, text=True) 

786 except subprocess.CalledProcessError as error: 

787 logging.error(f"stdout:\n{error.stdout}") 

788 logging.error(f"stderr:\n{error.stderr}") 

789 raise Exception(f"Error in command: {cmd}\n" + error.stderr) 

790 else: 

791 if not quiet: 

792 logging.info(result.stdout) 

793 if output: 

794 with open(output, 'w') as outfile: 

795 outfile.write(result.stdout) 

796 if err: 

797 with open(err, 'w') as errfile: 

798 errfile.write(result.stderr) 

799 return (result.stdout, result.stderr) 

800 

801def annotate( 

802 fasta: str, 

803 db: str, 

804 outdir: str = ".", 

805 prefix: str = None, 

806 sample: str = None, 

807 threads: int = 1, 

808 tmp: str = None, 

809 args: str = None 

810 ): 

811 """ 

812 Annotate genomic assemblies with bakta. 

813 

814 Takes as input a FASTA file of genomic assemblies. Outputs a GFF file 

815 of annotations, among many other formats from bakta. 

816 

817 Any additional arguments in `args` will be passed to `bakta`. 

818 

819 >>> annotate(fasta="sample1.fasta", db="database/bakta") 

820 >>> annotate(fasta="sample2.fasta", db="database/bakta", threads=2, args="--genus Streptococcus") 

821 

822 :param fasta: File path to the input FASTA file. 

823 :param db: Directory path of the bakta database. 

824 :param outdir: Output directory. 

825 :param prefix: Prefix for output files.  

826 :param sample: Sample identifier for output files. 

827 :param threads: CPU threads for bakta. 

828 :param tmp: Temporary directory. 

829 :param args: Str of additional arguments to pass to bakta 

830 

831 :return: Path to the output GFF annotations. 

832 """ 

833 from bakta.main import main as bakta 

834 

835 output_dir, tmp_dir = outdir, tmp 

836 

837 # Argument Checking 

838 args = args if args != None else "" 

839 sample = sample if sample else os.path.basename(fasta).split(".")[0] 

840 logging.info(f"Using sample identifier: {sample}") 

841 prefix = prefix if prefix != None else os.path.basename(fasta).split(".")[0] 

842 logging.info(f"Using prefix: {prefix}") 

843 if tmp != None: 

844 args += f" --tmp-dir {tmp}" 

845 

846 # bakta needs the temporary directory to already exist 

847 # bakta explicitly wants the output directory to NOT exist, 

848 # which is why we don't create it here. 

849 check_output_dir(tmp_dir) 

850 

851 # Set up cmd to execute, we're not going to use the run_cmd, because 

852 # bakta is a python package, and we call it directly! This also has the  

853 # benefit of better/immediate logging. 

854 cmd = f"bakta --force --prefix {prefix} --db {db} --threads {threads} --output {output_dir} {args} {fasta}" 

855 logging.info(f"{cmd}") 

856 

857 # Split cmd string and handle problematic whitespace 

858 cmd_run = [str(c) for c in cmd.replace("\n", " ").split(" ") if c != ""] 

859 # Manually set sys.argv for bakta 

860 sys_argv_original = sys.argv 

861 sys.argv = cmd_run 

862 bakta() 

863 

864 # Fix bad quotations 

865 for ext in ["embl", "faa", "ffn", "gbff", "gff3", "tsv"]: 

866 file_path = os.path.join(output_dir, f"{prefix}.{ext}") 

867 logging.info(f"Fixing bad quotations in file: {file_path}") 

868 with open(file_path) as infile: 

869 content = infile.read() 

870 with open(file_path, 'w') as outfile: 

871 outfile.write(content.replace('‘', '').replace('’', '')) 

872 

873 # Restore original sys.argv 

874 sys.argv = sys_argv_original 

875 

876 gff_path = os.path.join(output_dir, f"{sample}.gff3") 

877 return gff_path 

878 

879 

880def extract( 

881 gff: str, 

882 sample: str = None, 

883 fasta: str = None, 

884 outdir: str = ".", 

885 prefix: str = None, 

886 min_len: int = 20, 

887 max_len: int = 100000, 

888 regex: str = None, 

889 args: str = None 

890 ) -> OrderedDict: 

891 """ 

892 Extract sequence records from GFF annotations. 

893 

894 Takes as input a GFF annotations file. If sequences are not included, a FASTA 

895 of genomic contigs must also be provided. Both annotated and unannotated regions  

896 will be extracted. Outputs a TSV table of extracted sequence regions. 

897 

898 >>> extract(gff='sample1.gff3') 

899 >>> extract(gff='sample2.gff3', fasta="sample2.fasta", min_len=10) 

900 

901 :param gff: File path to the input GFF file. 

902 :param sample: A sample identifier that will be used in output files. 

903 :param fasta: File path to the input FASTA file (if GFF doesn't have sequences). 

904 :param outdir: Output directory. 

905 :param prefix: Output file path prefix. 

906 :param min_len: The minimum length of sequences/annotations to extract. 
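    :param max_len: The maximum length of sequences/annotations to extract.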

907 :param regex: Only extract lines that match this regex. 

908 :param args: Str of additional arguments [not implemented] 

909 

910 :return: Sequence records and annotations as an OrderedDict. 

911 """ 

912 

913 from Bio import SeqIO 

914 from io import StringIO 

915 import re 

916 

917 gff_path, fasta_path, output_dir = gff, fasta, outdir 

918 

919 # Argument checking 

920 args = args if args != None else "" 

921 sample = sample if sample != None else os.path.basename(gff).split(".")[0] 

922 logging.info(f"Using sample identifier: {sample}") 

923 prefix = f"{prefix}" if prefix != None else os.path.basename(gff).split(".")[0] 

924 logging.info(f"Using prefix: {prefix}") 

925 if "/" in prefix: 

926 msg = "Prefix cannot contain slashes (/)" 

927 logging.error(msg) 

928 raise Exception(msg) 

929 min_len = int(min_len) 

930 max_len = int(max_len) if max_len != None else None 

931 

932 # Check output directory 

933 check_output_dir(output_dir) 

934 

935 # A visual separator for creating locus names for unannotated regions 

936 # ex. 'SAMPL01_00035__SAMPL01_00040' indicating the region 

937 # falls between two annotated regions: SAMPL01_00035 and SAMPL01_00040 

938 delim = "__" 

939 

940 # ------------------------------------------------------------------------- 

941 # Extract contig sequences 

942 

943 logging.info(f"Reading GFF: {gff_path}") 

944 

945 contigs = OrderedDict() 

946 with open(gff) as infile: 

947 gff = infile.read() 

948 if "##FASTA" in gff: 

949 logging.info(f"Extracting sequences from GFF.") 

950 fasta_i = gff.index("##FASTA") 

951 fasta_string = gff[fasta_i + len("##FASTA"):].strip() 

952 fasta_io = StringIO(fasta_string) 

953 records = SeqIO.parse(fasta_io, "fasta") 

954 elif fasta_path == None: 

955 msg = f"A FASTA file must be provided if no sequences are in the GFF: {gff_path}" 

956 logging.error(msg) 

957 raise Exception(msg) 

958 

959 if fasta_path != None: 

960 logging.info(f"Extracting sequences from FASTA: {fasta_path}") 

961 if "##FASTA" in gff: 

962 logging.warning(f"Sequences found in GFF, fasta will be ignored: {fasta}") 

963 else: 

964 records = SeqIO.parse(fasta_path, "fasta") 

965 fasta_i = None 

966 

967 sequences_seen = set() 

968 for record in records: 

969 if record.id in sequences_seen: 

970 msg = f"Duplicate sequence ID found: {record.id}" 

971 logging.error(msg) 

972 raise Exception(msg) 

973 sequences_seen.add(record.id) 

974 contigs[record.id] = record.seq 

975 

976 # ------------------------------------------------------------------------- 

977 # Extract annotations 

978 

979 logging.info(f"Extracting annotations from gff: {gff_path}") 

980 

981 # Parsing GFF is not yet part of biopython, so this is done manually: 

982 # https://biopython.org/wiki/GFF_Parsing 

983 

984 annotations = OrderedDict() 

985 contig = None 

986 sequence = "" 

987 

988 # Keep track of loci, in case we need to flag duplicates 

989 locus_counts = {} 

990 gene_counts = {} 

991 locus_to_contig = {} 

992 

993 comment_contigs = set() 

994 

995 for i,line in enumerate(gff.split("\n")): 

996 # If we hit the fasta, we've seen all annotations 

997 if line.startswith("##FASTA"): break 

998 if line.startswith("##sequence-region"): 

999 _comment, contig, start, end = line.split(" ") 

1000 start, end = int(start), int(end) 

1001 if contig not in annotations: 

1002 annotations[contig] = {"start": start, "end": end, "length": 1 + end - start, "loci": OrderedDict()} 

1003 comment_contigs.add(contig) 

1004 continue 

1005 # Skip over all other comments 

1006 if line.startswith("#"): continue 

1007 # Skip over empty lines 

1008 if line == "": continue 

1009 

1010 # Parse standardized annotation fields 

1011 line = line.replace("\"", "") 

1012 line_split = line.split("\t") 

1013 if len(line_split) < 9: 

1014 msg = f"GFF record does not contain 9 fields: {line}" 

1015 logging.error(msg) 

1016 raise Exception(msg) 

1017 contig, _source, feature, start, end, _score, strand, _frame = [line_split[i] for i in range(0,8)] 

1018 start, end = int(start), int(end) 

1019 if feature == "region" or feature == "databank_entry": 

1020 if contig in comment_contigs: 

1021 continue 

1022 elif contig in annotations: 

1023 msg = f"Duplicate contig ID found: {contig}" 

1024 logging.error(msg) 

1025 raise Exception(msg) 

1026 annotations[contig] = {"start": start, "end": end, "length": 1 + end - start, "loci": OrderedDict()} 

1027 continue 

1028 elif regex != None and not re.match(regex, line, re.IGNORECASE): 

1029 continue 

1030 

1031 attributes = line_split[8] 

1032 attributes_dict = {a.split("=")[0].replace(" ", ""):a.split("=")[1] for a in attributes.split(";") if "=" in a} 

1033 locus = f"{attributes_dict['ID']}" 

1034 if "gene" in attributes_dict: 

1035 gene = attributes_dict["gene"] 

1036 if gene not in gene_counts: 

1037 gene_counts[gene] = {"count": 0, "current": 1} 

1038 gene_counts[gene]["count"] += 1 

1039 

1040 # Add sample prefix to help duplicate IDs later 

1041 if not locus.startswith(sample): 

1042 locus = f"{sample}_{locus}" 

1043 

1044 # Check for duplicate loci, how does this happen in NCBI annotations? 

1045 if locus not in locus_counts: 

1046 locus_counts[locus] = 1 

1047 else: 

1048 locus_counts[locus] += 1 

1049 dup_i = locus_counts[locus] 

1050 locus = f"{locus}.{dup_i}" 

1051 logging.debug(f"Duplicate locus ID found, flagging as: {locus}") 

1052 

1053 if contig not in contigs: 

1054 msg = f"Contig {contig} in {sample} annotations is not present in the sequences." 

1055 logging.error(msg) 

1056 raise Exception(msg) 

1057 sequence = contigs[contig][start - 1:end] 

1058 if strand == "-": 

1059 sequence = reverse_complement(sequence) 

1060 

1061 data = OrderedDict({ 

1062 "sample" : sample, 

1063 "contig" : contig, 

1064 "locus" : locus, 

1065 "feature" : feature, 

1066 "start" : start, 

1067 "end" : end, 

1068 "length" : 1 + end - start, 

1069 "strand" : strand, 

1070 "upstream" : "", 

1071 "downstream" : "", 

1072 "attributes" : attributes, 

1073 "sequence_id" : f"{locus}", 

1074 "sequence" : sequence, 

1075 }) 

1076 logging.debug(f"\tcontig={contig}, locus={locus}") 

1077 annotations[contig]["loci"][locus] = data 

1078 locus_to_contig[locus] = contig 

1079 

1080 # Check for duplicates, rename the first duplicate loci 

1081 logging.info(f"Checking for duplicate locus IDs.") 

1082 for locus,count in locus_counts.items(): 

1083 if count <= 1: continue 

1084 contig = locus_to_contig[locus] 

1085 data = annotations[contig]["loci"][locus] 

1086 new_locus = f"{locus}.1" 

1087 logging.debug(f"Duplicate locus ID found, flagging as: {new_locus}") 

1088 data["locus"], data["sequence_id"] = new_locus, new_locus 

1089 annotations[contig]["loci"] = OrderedDict({ 

1090 new_locus if k == locus else k:v 

1091 for k,v in annotations[contig]["loci"].items() 

1092 }) 

1093 

1094 # ------------------------------------------------------------------------- 

1095 # Extract unannotated regions 

1096 

1097 if regex != None: 

1098 logging.info("Skipping extraction of unannotated regions due to regex.") 

1099 else: 

1100 logging.info(f"Extracting unannotated regions.") 

1101 for contig,contig_data in annotations.items(): 

1102 if contig not in contigs: 

1103 msg = f"Contig {contig} in {sample} annotations is not present in the sequences." 

1104 logging.error(msg) 

1105 raise Exception(msg) 

1106 contig_sequence = contigs[contig] 

1107 num_annotations = len([k for k,v in contig_data["loci"].items() if v["feature"] != "region"]) 

1108 contig_start = contig_data["start"] 

1109 contig_end = contig_data["end"] 

1110 

1111 # Contig minimum length check 

1112 contig_len = (1 + contig_end - contig_start) 

1113 if contig_len < min_len: continue 

1114 

1115 logging.debug(f"\tcontig={contig}") 

1116 

1117 # If there were no annotations on the contig, extract entire contig as sequence 

1118 if num_annotations == 0: 

1119 logging.debug(f"\t\tlocus={contig}") 

1120 locus = f"{contig}" 

1121 if not locus.startswith(sample): 

1122 locus = f"{sample}_{locus}" 

1123 l_data = { 

1124 "sample" : sample, 

1125 "contig" : contig, 

1126 "locus" : locus, 

1127 "feature" : "unannotated", 

1128 "start" : contig_start, 

1129 "end" : contig_end, 

1130 "length" : 1 + contig_end - contig_start, 

1131 "strand" : "+", 

1132 "upstream" : f"{contig}_TERMINAL", 

1133 "downstream" : f"{contig}_TERMINAL", 

1134 "attributes" : "", 

1135 "sequence_id" : locus, 

1136 "sequence" : contig_sequence, 

1137 } 

1138 annotations[contig]["loci"][contig] = l_data 

1139 logging.debug(f"\t\t\tfull contig unannotated, locus={locus}") 

1140 continue 

1141 

1142 contig_annotations = list(contig_data["loci"].keys()) 

1143 

1144 for i,locus in enumerate(contig_annotations): 

1145 logging.debug(f"\t\tlocus={locus}") 

1146 l_data = contig_data["loci"][locus] 

1147 l_start, l_end = l_data["start"], l_data["end"] 

1148 

1149 # Make a template for an unannotated region 

1150 l_data = { 

1151 "sample" : sample, 

1152 "contig" : contig, 

1153 "locus" : None, 

1154 "feature" : "unannotated", 

1155 "start" : None, 

1156 "end" : None, 

1157 "length" : 0, 

1158 "strand" : "+", 

1159 "upstream" : "", 

1160 "downstream" : "", 

1161 "attributes" : "", 

1162 "sequence_id" : None, 

1163 "sequence" : None, 

1164 } 

1165 # Find the inter-genic regions and upstream/downstream loci 

1166 start, end, sequence, upstream, downstream = None, None, None, None, None 

1167 

1168 # Case 1. Unannotated at the start of the contig 

1169 if i == 0 and l_start != contig_start: 

1170 start = 1 

1171 end = l_start - 1 

1172 length = 1 + end - start 

1173 if length >= min_len and (max_len == None or length <= max_len): 

1174 upstream = f"{contig}_TERMINAL" 

1175 downstream = locus 

1176 # Base strand on the downstream loci 

1177 strand = annotations[contig]["loci"][downstream]["strand"] 

1178 sequence_id = f"{upstream}{delim}{downstream}" 

1179 sequence = contig_sequence[start - 1:end] 

1180 # If the downstream is reverse, treat unannotated as reversed 

1181 if strand == "-": 

1182 sequence = reverse_complement(sequence) 

1183 l_data_start = copy.deepcopy(l_data) 

1184 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]): 

1185 l_data_start[k] = v 

1186 l_data_start["upstream"], l_data_start["downstream"] = upstream, downstream 

1187 logging.debug(f"\t\t\tunannotated at start, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}") 

1188 annotations[contig]["loci"][sequence_id] = l_data_start 

1189 

1190 # Update upstream for annotation 

1191 annotations[contig]["loci"][locus]["upstream"] = sequence_id 

1192 

1193 # Case 2. Unannotated at the end of the contig 

1194 if i == (num_annotations - 1) and l_end != contig_end: 

1195 start = l_end + 1 

1196 end = contig_end 

1197 length = 1 + end - start 

1198 if length >= min_len and (max_len == None or length <= max_len): 

1199 upstream = locus 

1200 # base strand on the upstream loci 

1201 strand = annotations[contig]["loci"][upstream]["strand"] 

1202 downstream = f"{contig}_TERMINAL" 

1203 sequence_id = f"{upstream}{delim}{downstream}" 

1204 sequence = contig_sequence[start - 1:end] 

1205 # If the upstream is reversed, treat unannotated as reversed 

1206 if strand == "-": 

1207 sequence = reverse_complement(sequence) 

1208 l_data_end = copy.deepcopy(l_data) 

1209 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]): 

1210 l_data_end[k] = v 

1211 l_data_end["upstream"], l_data_end["downstream"] = upstream, downstream 

1212 logging.debug(f"\t\t\tunannotated at end, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}") 

1213 annotations[contig]["loci"][sequence_id] = l_data_end 

1214 

1215 # Update downstream for annotation 

1216 annotations[contig]["loci"][locus]["downstream"] = sequence_id 

1217 

1218 # Case 3. Unannotated in between annotations 

1219 if num_annotations > 1 and i != (num_annotations - 1): 

1220 

1221 upstream = locus 

1222 downstream = contig_annotations[i+1] 

1223 start = l_end + 1 

1224 end = contig_data["loci"][downstream]["start"] - 1 

1225 length = 1 + end - start 

1226 # Set upstream/downstream based on order in GFF 

1227 upstream_strand = annotations[contig]["loci"][upstream]["strand"] 

1228 downstream_strand = annotations[contig]["loci"][downstream]["strand"] 

1229 

1230 # Check that the region is long enough 

1231 if length >= min_len and (max_len == None or length <= max_len): 

1232 sequence_id = f"{upstream}{delim}{downstream}" 

1233 sequence = contig_sequence[start - 1:end] 

1234 # Identify strand and orientation  

1235 if upstream_strand == "-" and downstream_strand == "-": 

1236 strand = "-" 

1237 sequence = reverse_complement(sequence) 

1238 else: 

1239 strand = "+" 

1240 l_data_middle = copy.deepcopy(l_data) 

1241 for k,v in zip(["start", "end", "length", "sequence_id", "sequence", "locus", "strand"], [start, end, length, sequence_id, sequence, sequence_id, strand]): 

1242 l_data_middle[k] = v 

1243 l_data_middle["upstream"], l_data_middle["downstream"] = upstream, downstream 

1244 logging.debug(f"\t\t\tunannotated in middle, contig={contig}, sequence_id={sequence_id}, start: {start}, end: {end}") 

1245 annotations[contig]["loci"][sequence_id] = l_data_middle 

1246 
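    # Editor's sketch (illustrative only; locus names hypothetical, delim assumed
    # to be its default '__'): an unannotated region found between the annotated
    # loci 'sample1_00005' and 'sample1_00010' is stored under the ID
    # 'sample1_00005__sample1_00010', its upstream/downstream fields point at
    # those loci, and its strand is '-' only when both flanking loci are on the
    # reverse strand (Case 3); regions at contig ends take the strand of their
    # single annotated neighbour (Cases 1 and 2).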

1247 # ------------------------------------------------------------------------- 

1248 # Order and Filter 

1249 

1250 logging.info(f"Ordering records by contig and coordinate.") 

1251 contig_order = list(annotations.keys()) 

1252 for contig in contig_order: 

1253 loci = annotations[contig]["loci"] 

1254 # sort by start and end position 

1255 loci_sorted = OrderedDict(sorted(loci.items(), key=lambda item: [item[1]["start"], item[1]["end"]]) ) 

1256 annotations[contig]["loci"] = OrderedDict() 

1257 for locus,locus_data in loci_sorted.items(): 

1258 # final checks, minimum length and not entirely ambiguous characters 

1259 non_ambig = list(set([n for n in locus_data["sequence"] if n in NUCLEOTIDES])) 

1260 length = locus_data["length"] 

1261 if (length >= min_len) and (len(non_ambig) > 0): 

1262 if max_len == None or length <= max_len: 

1263 annotations[contig]["loci"][locus] = locus_data 

1264 

1265 

1266 # ------------------------------------------------------------------------- 

1267 # Upstream/downstream loci 

1268 

1269 for contig in annotations: 

1270 contig_data = annotations[contig] 

1271 c_start, c_end = contig_data["start"], contig_data["end"] 

1272 loci = list(contig_data["loci"].keys()) 

1273 for i,locus in enumerate(contig_data["loci"]): 

1274 upstream, downstream = "", "" 

1275 locus_data = contig_data["loci"][locus] 

1276 l_start, l_end = locus_data["start"], locus_data["end"] 

1277 # Annotation at very start 

1278 if i > 0: 

1279 upstream = loci[i-1] 

1280 elif l_start == c_start: 

1281 upstream = f"{contig}_TERMINAL" 

1282 # Annotation at very end 

1283 if i < (len(loci) - 1): 

1284 downstream = loci[i+1] 

1285 elif l_end == c_end: 

1286 downstream = f"{contig}_TERMINAL" 

1287 annotations[contig]["loci"][locus]["upstream"] = upstream 

1288 annotations[contig]["loci"][locus]["downstream"] = downstream 

1289 

1290 # ------------------------------------------------------------------------- 

1291 # Write Output 

1292 

1293 tsv_path = os.path.join(output_dir, prefix + ".tsv") 

1294 logging.info(f"Writing output tsv: {tsv_path}") 

1295 with open(tsv_path, 'w') as outfile: 

1296 header = None 

1297 for contig,contig_data in annotations.items(): 

1298 for locus,locus_data in contig_data["loci"].items(): 

1299 if locus_data["feature"] == "region": continue 

1300 if header == None: 

1301 header = list(locus_data.keys()) 

1302 outfile.write("\t".join(header) + "\n") 

1303 row = [str(v) for v in locus_data.values()] 

1304 # Restore commas from their %2C encoding 

1305 line = "\t".join(row).replace("%2C", ",") 

1306 outfile.write(line + "\n") 

1307 

1308 return tsv_path 

1309 

1310 

1311def collect( 

1312 tsv: list = None, 

1313 tsv_paths: str = None, 

1314 outdir : str = ".", 

1315 prefix : str = None, 

1316 args: str = None, 

1317 ) -> str: 

1318 """ 

1319 Collect sequences from multiple samples into one file. 

1320 

1321 Takes as input multiple TSV files from the extract subcommand, which can  

1322 be supplied as either space-separated paths or a text file containing paths.  

1323 Duplicate sequence IDs will be identified and given the suffix '.#'. 

1324 Outputs concatenated FASTA and TSV files. 

1325 

1326 >>> collect(tsv=["sample1.tsv", "sample2.tsv"]) 

1327 >>> collect(tsv_paths='paths.txt') 

1328 

1329 :param tsv: List of TSV file paths from the output of extract. 

1330 :param tsv_paths: TXT file with each line containing a TSV file path. 

1331 :param outdir: Path to the output directory. 

1332 :param prefix: Output file prefix. 

1333 :param args: Additional arguments [not implemented] 

1334 

1335 :return: Tuple of output FASTA and TSV paths. 

1336 """ 

1337 

1338 from Bio import SeqIO 

1339 

1340 tsv_paths, tsv_txt_path, output_dir = tsv, tsv_paths, outdir 

1341 prefix = f"{prefix}." if prefix != None else "" 

1342 

1343 # Check output directory 

1344 check_output_dir(output_dir) 

1345 

1346 # TSV file paths 

1347 tsv_file_paths = [] 

1348 if tsv_txt_path != None: 

1349 logging.info(f"Reading tsv paths from file: {tsv_txt_path}") 

1350 with open(tsv_txt_path) as infile: 

1351 for line in infile: 

1352 file_path = [l.strip() for l in line.split("\t")][0] 

1353 tsv_file_paths.append(file_path.strip()) 

1354 if tsv_paths != None: 

1355 for file_path in tsv_paths: 

1356 tsv_file_paths.append(file_path) 

1357 

1358 logging.info(f"Checking for duplicate samples and sequence IDs.") 

1359 all_samples = [] 

1360 sequence_id_counts = {} 

1361 for file_path in tqdm(tsv_file_paths): 

1362 with open(file_path) as infile: 

1363 header = [l.strip() for l in infile.readline().split("\t")] 

1364 for i,line in enumerate(infile): 

1365 row = [l.strip() for l in line.split("\t")] 

1366 data = {k:v for k,v in zip(header, row)} 

1367 

1368 # Check for duplicate samples 

1369 if i == 0: 

1370 sample = data["sample"] 

1371 if sample in all_samples: 

1372 msg = f"Duplicate sample ID found: {sample}" 

1373 logging.error(msg) 

1374 raise Exception(msg) 

1375 all_samples.append(sample) 

1376 

1377 # Check for duplicate IDs 

1378 sequence_id = data["sequence_id"] 

1379 if sequence_id not in sequence_id_counts: 

1380 sequence_id_counts[sequence_id] = {"count": 0, "current": 1} 

1381 sequence_id_counts[sequence_id]["count"] += 1 

1382 
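    # Editor's sketch (hypothetical IDs): after this first pass, a sequence ID
    # seen three times is tracked as {"count": 3, "current": 1}; in the second
    # pass below, each of its occurrences is renamed in turn to ID.1, ID.2 and
    # ID.3 as "current" is incremented, matching the '.#' suffix described in
    # the docstring.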

1383 output_fasta = os.path.join(outdir, f"{prefix}sequences.fasta") 

1384 logging.info(f"Writing output fasta file to: {output_fasta}") 

1385 output_tsv = os.path.join(outdir, f"{prefix}regions.tsv") 

1386 logging.info(f"Writing output tsv file to: {output_tsv}") 

1387 

1388 fasta_outfile = open(output_fasta, 'w') 

1389 tsv_outfile = open(output_tsv, "w") 

1390 header = None 

1391 

1392 for file_path in tqdm(tsv_file_paths): 

1393 with open(file_path) as infile: 

1394 header_line = infile.readline() 

1395 if header == None: 

1396 header = [l.strip() for l in header_line.split("\t")] 

1397 tsv_outfile.write(header_line) 

1398 for line in infile: 

1399 row = [l.strip() for l in line.split("\t")] 

1400 data = {k:v for k,v in zip(header, row)} 

1401 

1402 # Handle duplicate sequence IDs with a .# suffix  

1403 sequence_id = data["sequence_id"] 

1404 count = sequence_id_counts[sequence_id]["count"] 

1405 if count > 1: 

1406 dup_i = sequence_id_counts[sequence_id]["current"] 

1407 sequence_id_counts[sequence_id]["current"] += 1 

1408 data["sequence_id"] = sequence_id = f"{sequence_id}.{dup_i}" 

1409 logging.debug(f"Duplicate sequence ID found in TSV, flagging as: {sequence_id}") 

1410 

1411 line = "\t".join([str(data[col]) for col in header]) 

1412 tsv_outfile.write(line + "\n") 

1413 

1414 sequence = data["sequence"] 

1415 line = f">{sequence_id}\n{sequence}" 

1416 fasta_outfile.write(line + "\n") 

1417 

1418 fasta_outfile.close() 

1419 tsv_outfile.close() 

1420 

1421 return (output_fasta, output_tsv) 

1422 

1423 

1424def cluster( 

1425 fasta: str, 

1426 outdir: str = ".", 

1427 prefix: str = None, 

1428 threads: int = 1, 

1429 memory: str = "1G", 

1430 tmp: str = "tmp", 

1431 clean: bool = True, 

1432 args: str = CLUSTER_ARGS, 

1433 ): 

1434 """ 

1435 Cluster nucleotide sequences with mmseqs. 

1436 

1437 Takes as input a FASTA file of sequences for clustering from collect. 

1438 Calls MMSeqs2 to cluster sequences and identify a representative sequence.  

1439 Outputs a TSV table of sequence clusters and a FASTA of representative sequences. 

1440 

1441 Note: The default kmer size (15) requires at least 9G of memory. To use less 

1442 memory, please set the kmer size to '-k 13'. 

1443 

1444 >>> cluster(fasta='sequences.fasta') 

1445 >>> cluster(fasta='sequences.fasta', threads=2, memory='2G', args='-k 13 --min-seq-id 0.90 -c 0.90') 

1446 >>> cluster(fasta='sequences.fasta', threads=4, args='-k 13 --min-seq-id 0.90 -c 0.90') 

1447 

1448 :param fasta: Path to FASTA sequences. 

1449 :param outdir: Output directory. 

1450 :param prefix: Prefix for output files. 

1451 :param threads: CPU threads for MMSeqs2. 

1452 :param memory: Memory for MMSeqs2. 

1453 :param tmp: Path to a temporary directory. 

1454 :param clean: True if intermediate files should be cleaned up. 

1455 :param args: Additional parameters for MMSeqs2 cluster command.  

1456 """ 

1457 

1458 from Bio import SeqIO 

1459 

1460 fasta_path, tmp_dir, output_dir = fasta, tmp, outdir 

1461 prefix = f"{prefix}." if prefix != None else "" 

1462 

1463 # Wrangle the output directory 

1464 check_output_dir(output_dir) 

1465 check_output_dir(tmp_dir) 

1466 

1467 args = args if args != None else "" 

1468 

1469 # fix memory formatting (ex. '6 GB' -> '6G') 

1470 memory = memory.replace(" ", "").replace("B", "") 

1471 

1472 # 1. Cluster Sequences 

1473 seq_db = os.path.join(output_dir, f"{prefix}seqDB") 

1474 clust_db = os.path.join(output_dir, f"{prefix}clustDB") 

1475 tsv_path = os.path.join(output_dir, f"{prefix}clusters.tsv") 

1476 run_cmd(f"mmseqs createdb {fasta_path} {seq_db}") 

1477 try: 

1478 run_cmd(f"mmseqs cluster {seq_db} {clust_db} {tmp_dir} --threads {threads} --split-memory-limit {memory} {args}") 

1479 except Exception as e: 

1480 if "Segmentation fault" in f"{e}": 

1481 logging.error(f"Segmentation fault. Try changing your threads and/or memory. Memory consumption is {threads} x {memory}.") 

1482 raise Exception(e) 

1483 run_cmd(f"mmseqs createtsv {seq_db} {seq_db} {clust_db} {tsv_path} --threads {threads} --full-header") 

1484 

1485 # Sort and remove quotations in clusters 

1486 with open(tsv_path) as infile: 

1487 lines = sorted([l.strip().replace("\"", "") for l in infile.readlines()]) 

1488 with open(tsv_path, 'w') as outfile: 

1489 outfile.write("\n".join(lines) + "\n") 

1490 

1491 # 2. Identify representative sequences 

1492 rep_db = os.path.join(output_dir, f"{prefix}repDB") 

1493 rep_fasta_path = os.path.join(output_dir, f"{prefix}representative.fasta") 

1494 run_cmd(f"mmseqs result2repseq {seq_db} {clust_db} {rep_db} --threads {threads}") 

1495 run_cmd(f"mmseqs result2flat {seq_db} {seq_db} {rep_db} {rep_fasta_path} --use-fasta-header") 

1496 

1497 # Sort sequences, and remove trailing whitespace 

1498 sequences = {} 

1499 for record in SeqIO.parse(rep_fasta_path, "fasta"): 

1500 locus = record.id.strip() 

1501 seq = record.seq.strip() 

1502 sequences[locus] = seq 

1503 with open(rep_fasta_path, 'w') as outfile: 

1504 for locus in sorted(list(sequences.keys())): 

1505 seq = sequences[locus] 

1506 outfile.write(f">{locus}\n{seq}\n") 

1507 

1508 # Sort clusters 

1509 with open(tsv_path) as infile: 

1510 lines = sorted([line.strip() for line in infile]) 

1511 with open(tsv_path, 'w') as outfile: 

1512 outfile.write("\n".join(lines)) 

1513 

1514 # Cleanup 

1515 if clean == True: 

1516 for file_name in os.listdir(output_dir): 

1517 file_path = os.path.join(output_dir, file_name) 

1518 db_prefix = os.path.join(output_dir, prefix) 

1519 if ( 

1520 file_path.startswith(f"{db_prefix}seqDB") or 

1521 file_path.startswith(f"{db_prefix}clustDB") or 

1522 file_path.startswith(f"{db_prefix}repDB") 

1523 ): 

1524 logging.info(f"Cleaning up file: {file_path}") 

1525 os.remove(file_path) 

1526 

1527 return(tsv_path, rep_fasta_path) 

1528 

1529 

1530def defrag( 

1531 clusters: str, 

1532 representative: str, 

1533 outdir: str = ".", 

1534 prefix: str = None, 

1535 tmp: str = "tmp", 

1536 threads: int = 1, 

1537 memory: str = "2G", 

1538 clean: bool = True, 

1539 args: str = DEFRAG_ARGS, 

1540 ) -> OrderedDict: 

1541 """ 

1542 Defrag clusters by associating fragments with their parent cluster. 

1543 

1544 Takes as input the TSV clusters and FASTA representatives from the cluster subcommand. 

1545 Outputs a new cluster table and representative sequences fasta. 

1546 

1547 This is a modification of ppanggolin's refine_clustering function: 

1548 https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317 

1549 

1550 >>> defrag(clusters='clusters.tsv', representative='representative.fasta', prefix="defrag") 

1551 >>> defrag(clusters='clusters.tsv', representative='representative.fasta', threads=2, memory='2G', args="-k 13 --min-seq-id 0.90 -c 0.90 --cov-mode 1") 

1552 

1553 :param clusters: TSV file of clusters. 

1554 :param representative: FASTA file of representative sequences from cluster (mmseqs) 

1555 :param outdir: Path to the output directory. 

1556 :param prefix: Prefix for output files. 

1557 :param tmp: Path to a temporary directory. 

1558 :param threads: CPU threads for mmseqs. 

1559 :param memory: Memory for mmseqs. 

1560 :param clean: True if intermediate DB files should be cleaned up. 

1561 :param args: Additional arguments for `mmseqs search`. 

1562 

1563 

1564 :return: Ordered Dictionary of new defragmented clusters 

1565 """ 

1566 

1567 from Bio import SeqIO 

1568 

1569 clusters_path, representative_path, output_dir, tmp_dir = clusters, representative, outdir, tmp 

1570 prefix = f"{prefix}." if prefix != None else "" 

1571 

1572 # Check output directory 

1573 check_output_dir(output_dir) 

1574 check_output_dir(tmp_dir) 

1575 

1576 args = args if args != None else "" 

1577 

1578 # fix memory formatting (ex. '6 GB' -> '6G') 

1579 memory = memory.replace(" ", "").replace("B", "") 

1580 

1581 # ------------------------------------------------------------------------- 

1582 # Align representative sequences against each other 

1583 

1584 logging.info("Aligning representative sequences.") 

1585 

1586 seq_db = os.path.join(output_dir, f"{prefix}seqDB") 

1587 aln_db = os.path.join(output_dir, f"{prefix}alnDB") 

1588 tsv_path = os.path.join(output_dir, f"{prefix}align.tsv") 

1589 run_cmd(f"mmseqs createdb {representative_path} {seq_db}") 

1590 run_cmd(f"mmseqs search {seq_db} {seq_db} {aln_db} {tmp_dir} --threads {threads} --split-memory-limit {memory} --search-type 3 {args}") 

1591 columns="query,target,fident,alnlen,mismatch,gapopen,qstart,qend,qlen,tstart,tend,tlen,evalue,bits" 

1592 run_cmd(f"mmseqs convertalis {seq_db} {seq_db} {aln_db} {tsv_path} --search-type 3 --format-output {columns}") 

1593 

1594 # Sort the representative alignment stats for reproducibility 

1595 with open(tsv_path, 'r') as infile: 

1596 lines = sorted([l.strip() for l in infile.readlines()]) 

1597 with open(tsv_path, 'w') as outfile: 

1598 header = "\t".join(columns.split(",")) 

1599 outfile.write(header + "\n") 

1600 outfile.write("\n".join(lines) + "\n") 

1601 

1602 

1603 # ------------------------------------------------------------------------- 

1604 logging.info(f"Reading clusters: {clusters_path}") 

1605 loci = OrderedDict() 

1606 clusters = OrderedDict() 

1607 with open(clusters_path) as infile: 

1608 lines = infile.readlines() 

1609 for line in tqdm(lines): 

1610 cluster, locus = [l.strip() for l in line.split("\t")] 

1611 loci[locus] = cluster 

1612 if cluster not in clusters: 

1613 clusters[cluster] = [] 

1614 if locus not in clusters[cluster]: 

1615 clusters[cluster].append(locus) 

1616 

1617 # ------------------------------------------------------------------------- 

1618 logging.info(f"Reading representative: {representative_path}") 

1619 representative = OrderedDict() 

1620 for record in SeqIO.parse(representative_path, "fasta"): 

1621 representative[record.id] = record.seq 

1622 

1623 # ------------------------------------------------------------------------- 

1624 # Similarity Graph 

1625 

1626 # Create a graph of cluster relationships, based on the alignment of their 

1627 # representative sequences. The edges between clusters will represent the 

1628 # pairwise alignment score (bits). 

1629 

1630 # This is a modification of ppanggolin's refine_clustering function: 

1631 # https://github.com/labgem/PPanGGOLiN/blob/2.2.0/ppanggolin/cluster/cluster.py#L317 

1632 # The major difference is that the original function imposes a constraint 

1633 # that the fragmented cluster cannot contain more loci than the new parent.  

1634 # This function does not use the number of loci, which allows a cluster 

1635 # with many small fragments to be reassigned to a longer parent that might 

1636 # be represented by only one intact sequence. This function also prefers 

1637 # an OrderedDict over a formal graph object, so we can avoid the networkx 

1638 # dependency. 

1639 
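    # Editor's sketch (hypothetical representatives and scores): the resulting
    # structure is a plain nested OrderedDict rather than a networkx graph, e.g.
    #
    #   graph = OrderedDict({
    #       "repA": OrderedDict({"repB": {"score": 250.0, "length": 900}}),
    #       "repB": OrderedDict({"repA": {"score": 250.0, "length": 300}}),
    #   })
    #
    # Each edge stores the alignment bit score and the length of the node's own
    # representative sequence, so repB (300 bp) can later be recognized as a
    # fragment of the longer repA (900 bp).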

1640 logging.info(f"Creating similarity graph from alignment: {tsv_path}") 

1641 

1642 graph = OrderedDict() 

1643 

1644 with open(tsv_path) as infile: 

1645 header = infile.readline().split() 

1646 cluster_i, locus_i, qlen_i, tlen_i, bits_i = [header.index(c) for c in ["query", "target", "qlen", "tlen", "bits"]] 

1647 lines = infile.readlines() 

1648 for line in tqdm(lines): 

1649 row = [r.strip() for r in line.replace('"', '').split()] 

1650 if row == []: continue 

1651 query, target, qlen, tlen, bits = [row[i] for i in [cluster_i, locus_i, qlen_i, tlen_i, bits_i]] 

1652 if query != target: 

1653 if query not in graph: 

1654 graph[query] = OrderedDict() 

1655 if target not in graph: 

1656 graph[target] = OrderedDict() 

1657 graph[query][target] = {"score": float(bits), "length": int(qlen)} 

1658 graph[target][query] = {"score": float(bits), "length": int(tlen)} 

1659 

1660 # ------------------------------------------------------------------------- 

1661 logging.info(f"Identifying fragmented loci.") 

1662 

1663 # Reassign fragmented loci to their new 'parent' cluster which must: 

1664 # 1. Be longer (length) 

1665 # 2. Have a higher score. 

1666 

1667 defrag_clusters = OrderedDict({c:{ 

1668 "loci": clusters[c], 

1669 "sequence": representative[c], 

1670 "fragments": []} 

1671 for c in clusters 

1672 }) 

1673 

1674 reassigned = {} 

1675 

1676 nodes = list(graph.keys()) 

1677 

1678 for node in tqdm(nodes): 

1679 # Get the current parent node 

1680 candidate_node = None 

1681 candidate_score = 0 

1682 # Iterate through the candidate targets (clusters it could be aligned against) 

1683 logging.debug(f"query: {node}") 

1684 for target in graph[node]: 

1685 ndata = graph[node][target] 

1686 tdata = graph[target][node] 

1687 nlen, tlen, tscore = ndata["length"], tdata["length"], tdata["score"] 

1688 # Compare lengths and scores 

1689 logging.debug(f"\ttarget: {target}, tlen: {tlen}, nlen: {nlen}, tscore: {tscore}, cscore: {candidate_score}") 

1690 if tlen > nlen and candidate_score < tscore: 

1691 candidate_node = target 

1692 candidate_score = tscore 

1693 

1694 # Check candidate 

1695 if candidate_node is not None: 

1696 # candidate node might have also been a fragment that got reassigned 

1697 while candidate_node not in defrag_clusters and candidate_node in reassigned: 

1698 new_candidate_node = reassigned[candidate_node] 

1699 logging.debug(f"Following fragments: {node}-->{candidate_node}-->{new_candidate_node}") 

1700 candidate_node = new_candidate_node 

1701 for locus in clusters[node]: 

1702 defrag_clusters[candidate_node]["loci"].append(locus) 

1703 defrag_clusters[candidate_node]["fragments"].append(locus) 

1704 del defrag_clusters[node] 

1705 reassigned[node] = candidate_node 

1706 

1707 # Sort order for reproducibility 

1708 defrag_cluster_order = sorted(list(defrag_clusters.keys())) 

1709 

1710 defrag_clusters_path = os.path.join(output_dir, f"{prefix}clusters.tsv") 

1711 defrag_rep_path = os.path.join(output_dir, f"{prefix}representative.fasta") 

1712 

1713 with open(defrag_clusters_path, 'w') as clust_file: 

1714 with open(defrag_rep_path, 'w') as seq_file: 

1715 for cluster in defrag_cluster_order: 

1716 info = defrag_clusters[cluster] 

1717 # Write sequence 

1718 seq_file.write(f">{cluster}\n{info['sequence']}" + "\n") 

1719 # Write cluster loci, we sort for reproducibility 

1720 for locus in sorted(info["loci"]): 

1721 fragment = "F" if locus in info["fragments"] else "" 

1722 line = f"{cluster}\t{locus}\t{fragment}" 

1723 clust_file.write(line + '\n') 

1724 

1725 # Cleanup 

1726 if clean == True: 

1727 for file_name in os.listdir(output_dir): 

1728 file_path = os.path.join(output_dir, file_name) 

1729 db_prefix = os.path.join(output_dir, prefix) 

1730 if ( 

1731 file_path.startswith(f"{db_prefix}seqDB") or 

1732 file_path.startswith(f"{db_prefix}alnDB") 

1733 ): 

1734 logging.info(f"Cleaning up file: {file_path}") 

1735 os.remove(file_path) 

1736 

1737 logging.info(f"IMPORTANT!\n{PPANGGOLIN_NOTICE}") 

1738 

1739 return (defrag_clusters_path, defrag_rep_path) 

1740 

1741 

1742def summarize( 

1743 clusters: str, 

1744 regions: str, 

1745 outdir: str = ".", 

1746 product_clean: dict = { 

1747 " ": "_", 

1748 "putative" : "", 

1749 "hypothetical": "", 

1750 "(": "", 

1751 ")": "", 

1752 ",": "", 

1753 }, 

1754 max_product_len: int = 50, 

1755 min_samples: int = 1, 

1756 threshold: float = 0.5, 

1757 prefix: str = None, 

1758 args: str = None, 

1759 ): 

1760 """ 

1761 Summarize clusters according to their annotations. 

1762 

1763 Takes as input the clusters TSV from either the cluster or defrag subcommand, 

1764 and the TSV table of annotations from the collect subcommand. 

1765 Outputs a TSV table of clusters and their summarized annotations. 

1766 

1767 >>> summarize(clusters="clusters.tsv", regions="regions.tsv", prefix="summarize") 

1768 

1769 :param clusters: TSV file of clusters from cluster or defrag (mmseqs). 

1770 :param regions: TSV file of sequence regions from collect. 

1771 :param outdir: Output directory. 

1772 :param product_clean: Characters and words to strip or replace in the product when it is used as an identifier. 

1773 :param max_product_len: Truncate product to this length if it's being used as an identifier. 

1774 :param prefix: Prefix for output files. 

1775 :param args: Additional arguments [not implemented]. 

1776 

1777 :return: Ordered Dictionary of summarized clusters 

1778 """ 

1779 

1780 import networkx 

1781 from networkx.exception import NetworkXNoCycle 

1782 

1783 sequences_path, clusters_path, output_dir = regions, clusters, outdir 

1784 prefix = f"{prefix}." if prefix != None else "" 

1785 check_output_dir(output_dir) 

1786 

1787 args = args if args != None else "" 

1788 

1789 # Type conversions as fallback 

1790 threshold = float(threshold) 

1791 max_product_len, min_samples = int(max_product_len), int(min_samples) 

1792 

1793 # ------------------------------------------------------------------------- 

1794 # Read Sequence records 

1795 

1796 logging.info(f"Reading sequence regions: {sequences_path}") 

1797 sequences = OrderedDict() 

1798 all_samples = [] 

1799 with open(sequences_path) as infile: 

1800 header = [line.strip() for line in infile.readline().split("\t")] 

1801 lines = infile.readlines() 

1802 for line in tqdm(lines): 

1803 row = [l.strip() for l in line.split("\t")] 

1804 data = {k:v for k,v in zip(header, row)} 

1805 sample, sequence_id = data["sample"], data["sequence_id"] 

1806 if sample not in all_samples: 

1807 all_samples.append(sample) 

1808 if sequence_id not in sequences: 

1809 sequences[sequence_id] = data 

1810 # At this point, we should not have duplicate sequence IDs 

1811 # That would have been handled in the extract/collect command 

1812 else: 

1813 msg = f"Duplicate sequence ID found: {sequence_id}" 

1814 logging.error(msg) 

1815 raise Exception(msg) 

1816 

1817 # ------------------------------------------------------------------------- 

1818 # Read Clusters 

1819 

1820 logging.info(f"Reading clusters: {clusters_path}") 

1821 seen = set() 

1822 clusters = OrderedDict() 

1823 representative_to_cluster = {} 

1824 i = 0 

1825 

1826 with open(clusters_path) as infile: 

1827 lines = infile.readlines() 

1828 for line in tqdm(lines): 

1829 

1830 # Extract the 3 columns: cluster representative, sequence ID, and (optional) fragment 

1831 row = [l.strip() for l in line.split("\t")] 

1832 representative, sequence_id = row[0], row[1] 

1833 fragment = True if len(row) > 2 and row[2] == "F" else False 

1834 if representative not in seen: 

1835 i += 1 

1836 cluster = f"Cluster_{i}" 

1837 seen.add(representative) 

1838 representative_to_cluster[representative] = cluster 

1839 

1840 if sequence_id not in sequences: 

1841 msg = f"Sequence is present in clusters but not in regions: {sequence_id}" 

1842 logging.error(msg) 

1843 raise Exception(msg) 

1844 sequences[sequence_id]["cluster"] = cluster 

1845 

1846 # Start the cluster data 

1847 if cluster not in clusters: 

1848 clusters[cluster] = OrderedDict({ 

1849 "cluster": cluster, 

1850 "representative": representative, 

1851 "sequences": OrderedDict() 

1852 }) 

1853 if sequence_id not in clusters[cluster]["sequences"]: 

1854 clusters[cluster]["sequences"][sequence_id] = sequences[sequence_id] 

1855 clusters[cluster]["sequences"][sequence_id]["fragment"] = fragment 

1856 else: 

1857 msg = f"Duplicate cluster sequence ID found: {sequence_id}" 

1858 logging.error(msg) 

1859 raise Exception(msg) 

1860 

1861 logging.info(f"Found {len(clusters)} clusters.") 

1862 

1863 # ------------------------------------------------------------------------- 

1864 # Summarize 

1865 # -------------------------------------------------------------------------  

1866 

1867 logging.info(f"Summarizing clusters.") 

1868 summarized = OrderedDict() 

1869 gene_counts = {} 

1870 product_counts = {} 

1871 

1872 for cluster,cluster_data in tqdm(clusters.items()): 

1873 

1874 # Identify number of samples 

1875 cluster_samples = set() 

1876 for s in cluster_data["sequences"].values(): 

1877 cluster_samples.add(s["sample"]) 

1878 num_samples = len(list(cluster_samples)) 

1879 # If user requested a minimum number of samples 

1880 if num_samples < min_samples: 

1881 continue 

1882 samples_non_fragmented = set() 

1883 sequences_per_sample = {} 

1884 summarized[cluster] = OrderedDict({ 

1885 "cluster": cluster, 

1886 "cluster_id": cluster, 

1887 "synteny": "", 

1888 "synteny_pos": "", 

1889 "num_samples": num_samples, 

1890 "num_samples_non_fragmented": 0, 

1891 "num_sequences": 0, 

1892 "mean_sequences_per_sample": 0, 

1893 "representative": cluster_data["representative"], 

1894 "upstream": "", 

1895 "downstream": "", 

1896 "upstream_alt": "", 

1897 "downstream_alt": "" 

1898 } 

1899 ) 

1900 cluster_sequences = cluster_data["sequences"] 

1901 features = Counter() 

1902 genes = Counter() 

1903 products = Counter() 

1904 names = Counter() 

1905 dbxrefs = Counter() 

1906 strands = Counter() 

1907 contexts = Counter() 

1908 contexts_uniq = Counter() 

1909 

1910 # ------------------------------------------------------------------------- 

1911 # Collect sequence attributes 

1912 

1913 for sequence_id,seq_data in cluster_sequences.items(): 

1914 

1915 summarized[cluster]["num_sequences"] += 1 

1916 

1917 sample = seq_data["sample"] 

1918 if sample not in sequences_per_sample: 

1919 sequences_per_sample[sample] = 0 

1920 sequences_per_sample[sample] += 1 

1921 

1922 # Ignore annotations from fragments starting here 

1923 if seq_data["fragment"] == True: continue 

1924 

1925 samples_non_fragmented.add(sample) 

1926 features[seq_data["feature"]] += 1 

1927 strands[seq_data["strand"]] += 1 

1928 

1929 # collect the annotation attributes 

1930 attributes = {a.split("=")[0]:a.split("=")[1] for a in seq_data["attributes"].split(";") if "=" in a} 

1931 if "gene" in attributes: 

1932 genes[attributes["gene"]] += 1 

1933 if "product" in attributes: 

1934 products[attributes["product"]] += 1 

1935 if "Name" in attributes: 

1936 names[attributes["Name"]] += 1 

1937 if "Dbxref" in attributes: 

1938 dbxrefs[attributes["Dbxref"]] += 1 

1939 

1940 # Get the upstream/downstream locus IDs 

1941 upstream, downstream = seq_data["upstream"], seq_data["downstream"] 

1942 

1943 # Simplify upstream/downstream loci if they represent the start/end of a contig 

1944 upstream = "TERMINAL" if upstream.endswith("_TERMINAL") and "__" not in upstream else upstream 

1945 downstream = "TERMINAL" if downstream.endswith("_TERMINAL") and "__" not in downstream else downstream 

1946 

1947 # Get the upstream/downstream cluster IDs 

1948 # There is a possibility that the upstream/downstream sequences 

1949 # didn't actually get classified into a cluster 

1950 if upstream in sequences and "cluster" in sequences[upstream]: 

1951 upstream = sequences[upstream]["cluster"] 

1952 if downstream in sequences and "cluster" in sequences[downstream]: 

1953 downstream = sequences[downstream]["cluster"] 

1954 

1955 # In what cases would it be itself? 

1956 if upstream != cluster and downstream != cluster: 

1957 context = [upstream, downstream] 

1958 contexts["__".join(context)] += 1 

1959 contexts_uniq["__".join(sorted(context))] += 1 

1960 

1961 

1962 num_samples_non_fragmented = len(samples_non_fragmented) 

1963 summarized[cluster]["num_samples_non_fragmented"] = num_samples_non_fragmented 

1964 

1965 mean_sequences_per_sample = sum(sequences_per_sample.values()) / len(sequences_per_sample) 

1966 summarized[cluster]["mean_sequences_per_sample"] = round(mean_sequences_per_sample, 1) 

1967 

1968 # ------------------------------------------------------------------------- 

1969 # Summarize upstream/downstream 

1970 

1971 # Part 1. Are there neighboring loci (regardless of directionality)? 

1972 neighbors = None 

1973 if len(contexts_uniq) > 1: 

1974 most_common, count = contexts_uniq.most_common(1)[0] 

1975 # Are these loci observed frequently enough? 

1976 prop = count / num_samples_non_fragmented 

1977 if prop >= threshold: 

1978 # Handle ties, by simply picking the first option alphabetically 

1979 candidates = sorted([c for c,v in contexts_uniq.items() if v == count]) 

1980 neighbors = candidates[0] 

1981 if len(candidates) > 1: 

1982 logging.debug(f"{cluster} tie broken alphabetically: neighbors={neighbors}, {contexts_uniq}") 

1983 

1984 # Part 2. Filter the neighbors to our top matches, now allowing either direction 

1985 # We will summarize in the next step 

1986 if neighbors != None: 

1987 c1, c2 = most_common.split("__") 

1988 forward, reverse = f"{c1}__{c2}", f"{c2}__{c1}" 

1989 contexts = Counter({k:v for k,v in contexts.items() if k == forward or k == reverse}) 

1990 

1991 

1992 # ------------------------------------------------------------------------- 

1993 # Summarize sequence attributes 

1994 
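        # Editor's sketch (hypothetical counts): with 10 non-fragmented samples
        # and threshold=0.5, a product seen 6 times (proportion 0.6) becomes the
        # consensus value and the remaining products are joined into the
        # 'product_alt' column; a top value seen only 4 times (0.4) leaves the
        # field empty and every observed value goes to '_alt'. Ties at the top
        # count are broken alphabetically, and the threshold is not applied to
        # the 'contexts' counter.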

1995 for key,values in zip(["feature", "strand", "gene", "product", "name", "dbxref", "contexts"], [features, strands, genes, products, names, dbxrefs, contexts]): 

1996 value = "" 

1997 value_alt = "" 

1998 

1999 if len(values) > 0: 

2000 # Check if the most common value passes the threshold 

2001 most_common_count = values.most_common(1)[0][1] 

2002 most_common_prop = most_common_count / num_samples_non_fragmented 

2003 # We don't do a second threshold filter for contexts 

2004 if key == "contexts" or most_common_prop >= threshold : 

2005 # Handle ties, by simply picking the first option alphabetically 

2006 candidates = sorted([c for c,v in values.items() if v == most_common_count]) 

2007 value = candidates[0] 

2008 if len(candidates) > 1: 

2009 logging.debug(f"{cluster} tie broken alphabetically: {key}={value}, {values}") 

2010 value_alt = ";".join([v for v in values.keys() if v != value]) 

2011 else: 

2012 value_alt = ";".join([v for v in values.keys()]) 

2013 

2014 if key == "contexts": 

2015 if value != "": 

2016 upstream, downstream = value.split("__") 

2017 summarized[cluster]["upstream"] = upstream 

2018 summarized[cluster]["downstream"] = downstream 

2019 continue 

2020 

2021 summarized[cluster][key] = value 

2022 summarized[cluster][f"{key}_alt"] = value_alt 

2023 

2024 # gene/product identifiers need to be checked for case! 

2025 

2026 if key == "gene" and value != "": 

2027 value_lower = value.lower() 

2028 if value_lower not in gene_counts: 

2029 gene_counts[value_lower] = {"count": 0, "current": 1} 

2030 gene_counts[value_lower]["count"] += 1 

2031 

2032 if key == "product": 

2033 summarized[cluster]["product_clean"] = "" 

2034 if value != "": 

2035 # Clean up product as a potential identifier 

2036 # These values are definitely not allowed 

2037 clean_value = value.replace(" ", "_").replace("/", "_").replace(",", "_") 

2038 for k,v in product_clean.items(): 

2039 clean_value = clean_value.replace(k, v) 

2040 # Restrict the length 

2041 if max_product_len != None: 

2042 clean_value = clean_value[:max_product_len] 

2043 while "__" in clean_value: 

2044 clean_value = clean_value.replace("__", "_") 

2045 while clean_value.startswith("_") or clean_value.endswith("_"): 

2046 clean_value = clean_value.lstrip("_").rstrip("_") 

2047 if clean_value != "": 

2048 clean_value_lower = clean_value.lower() 

2049 if clean_value_lower not in product_counts: 

2050 product_counts[clean_value_lower] = {"count": 0, "current": 1} 

2051 product_counts[clean_value_lower]["count"] += 1 

2052 summarized[cluster]["product_clean"] = clean_value 

2053 

2054 # Handle if feature was 'unannotated' 

2055 # Can happen in tie breaking situations 

2056 if ( 

2057 (key == "gene" or key == "product") and 

2058 value != "" and 

2059 summarized[cluster]["feature"] == "unannotated" 

2060 ): 

2061 candidates = [f for f in summarized[cluster]["feature_alt"].split(";") if f != "unannotated"] 

2062 feature = candidates[0] 

2063 summarized[cluster]["feature"] = feature 

2064 features_alt = ["unannotated"] + [c for c in candidates if c != feature] 

2065 summarized[cluster]["feature_alt"] = ";".join(features_alt) 

2066 logging.debug(f"Updating {cluster} from an unannotated feature to {feature} based on {key}={value}.") 

2067 

2068 summarized[cluster]["sequences"] = cluster_sequences 

2069 

2070 # ------------------------------------------------------------------------- 

2071 # Identifiers: Part 1 

2072 # -------------------------------------------------------------------------  

2073 

2074 # Give identifiers to clusters based on their gene or product. 

2075 # We do this now, so that the synteny graph has some nice  

2076 # helpful names for clusters. We will give names to the  

2077 # unannotated clusters based on their upstream/downstream 

2078 # loci after the synteny graph is made. 

2079 
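    # Editor's sketch (hypothetical genes): gene/product counts are keyed by
    # their lowercase form, so 'AbcD' and 'abcd' share one entry in gene_counts.
    # If that entry has count 2, the two clusters are labelled 'AbcD.1' and
    # 'abcd.2' (original case kept, '.#' suffix appended); a gene seen in only
    # one cluster is used as the identifier unchanged.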

2080 logging.info(f"Assigning identifiers to annotated clusters.") 

2081 

2082 identifiers = OrderedDict() 

2083 

2084 # Pass 1. Identifiers based on gene/product 

2085 for cluster,cluster_data in tqdm(summarized.items()): 

2086 identifier = None 

2087 # Option 1. Try to use gene as cluster identifier 

2088 if cluster_data["gene"] != "": 

2089 gene = cluster_data["gene"] 

2090 # we use the lowercase to figure out the duplicate 

2091 # number, but use the original name in the table 

2092 gene_lower = gene.lower() 

2093 if gene_counts[gene_lower]["count"] == 1: 

2094 identifier = gene 

2095 else: 

2096 dup_i = gene_counts[gene_lower]["current"] 

2097 gene_counts[gene_lower]["current"] += 1 

2098 new_gene = f"{gene}.{dup_i}" 

2099 logging.debug(f"Duplicate gene identifier found for {cluster}: {new_gene}") 

2100 identifier = new_gene 

2101 

2102 # Option 2. Try to use product as cluster identifier 

2103 # Useful for non-gene annotations (ex. tRNA) 

2104 elif cluster_data["product_clean"] != "": 

2105 product = cluster_data["product_clean"] 

2106 product_lower = product.lower() 

2107 if product_counts[product_lower]["count"] == 1: 

2108 identifier = product 

2109 else: 

2110 dup_i = product_counts[product_lower]["current"] 

2111 product_counts[product_lower]["current"] += 1 

2112 new_product = f"{product}.{dup_i}" 

2113 logging.debug(f"Duplicate product identifier found for {cluster}: {new_product}") 

2114 identifier = new_product 

2115 

2116 if identifier != None: 

2117 identifiers[cluster] = identifier 

2118 summarized[cluster]["cluster"] = identifier 

2119 

2120 # Pass 2: Update the upstream/downstream identifiers 

2121 for cluster,cluster_data in tqdm(summarized.items()): 

2122 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"] 

2123 if upstream in identifiers: 

2124 summarized[cluster]["upstream"] = identifiers[upstream] 

2125 if downstream in identifiers: 

2126 summarized[cluster]["downstream"] = identifiers[downstream] 

2127 

2128 # Update the cluster keys 

2129 summarized = OrderedDict({v["cluster"]:v for k,v in summarized.items()}) 

2130 

2131 # ------------------------------------------------------------------------- 

2132 # Synteny 

2133 # ------------------------------------------------------------------------- 

2134 

2135 logging.info(f"Computing initial synteny graph.") 

2136 

2137 # -------------------------------------------------------------------------  

2138 # Create initial graph  

2139 

2140 # Create simple graph based on what we know about the 

2141 # upstream/downstream loci so far. This will be our "full" 

2142 # graph which retains all the cycles and multifurcations 

2143 synteny_full = networkx.DiGraph() 

2144 seen_nodes = set() 

2145 

2146 for cluster, cluster_data in summarized.items(): 

2147 if cluster not in seen_nodes: 

2148 synteny_full.add_node(cluster) 

2149 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"] 

2150 if upstream != "" and upstream != "TERMINAL": 

2151 synteny_full.add_edge(upstream, cluster) 

2152 if downstream != "" and downstream != "TERMINAL": 

2153 synteny_full.add_edge(cluster, downstream) 

2154 

2155 # -------------------------------------------------------------------------  

2156 # Filter nodes 

2157 

2158 # Remove clusters from the graph. This can happen if the  

2159 # user requested min_samples>0, so we need to remove the  

2160 # low prevalence clusters. It can also happen if a  

2161 # cluster had an upstream/downstream locus that didn't 

2162 # make it to the final clustering list 

2163 logging.info(f"Filtering graph for missing clusters.") 

2164 for node in list(networkx.dfs_tree(synteny_full)): 

2165 if node in summarized: continue 

2166 in_nodes = [e[0] for e in synteny_full.in_edges(node)] 

2167 out_nodes = [e[1] for e in synteny_full.out_edges(node)] 

2168 # Remove this node from the graph 

2169 logging.debug(f"Removing {node} from the graph.") 

2170 synteny_full.remove_node(node) 

2171 # Connect new edges between in -> out 

2172 for n1 in [n for n in in_nodes if n not in out_nodes]: 

2173 for n2 in [n for n in out_nodes if n not in in_nodes]: 

2174 # Not sure if we need this check? 

2175 if not synteny_full.has_edge(n1, n2): 

2176 synteny_full.add_edge(n1, n2) 

2177 

2178 # -------------------------------------------------------------------------  

2179 # Break up multifurcations 

2180 

2181 synteny_linear = copy.deepcopy(synteny_full) 

2182 

2183 logging.info(f"Breaking up multifurcations.") 

2184 for node in tqdm(list(networkx.dfs_tree(synteny_linear))): 

2185 in_nodes = [e[0] for e in synteny_linear.in_edges(node)] 

2186 out_nodes = [e[1] for e in synteny_linear.out_edges(node)] 

2187 

2188 if len(in_nodes) > 1: 

2189 for n in in_nodes: 

2190 logging.debug(f"Removing multifurcation in_edge: {n} -> {node}") 

2191 synteny_linear.remove_edge(n, node) 

2192 if len(out_nodes) > 1: 

2193 for n in out_nodes: 

2194 logging.debug(f"Removing multifurcation out_edge: {node} -> {n}") 

2195 synteny_linear.remove_edge(node, n) 

2196 

2197 # -------------------------------------------------------------------------  

2198 # Isolate and linearize cycles 

2199 

2200 # I'm not sure in what cases this is still needed 

2201 # after switching the graph to directed 

2202 logging.info(f"Isolating cycles and linearizing.") 

2203 cycles = True 

2204 while cycles == True: 

2205 try: 

2206 cycle_raw = [x for xs in networkx.find_cycle(synteny_linear) for x in xs] 

2207 seen = set() 

2208 cycle = [] 

2209 for n in cycle_raw: 

2210 if n not in seen: 

2211 cycle.append(n) 

2212 seen.add(n) 

2213 logging.debug(f"Cycle found: {cycle}") 

2214 cycle_set = set(cycle) 

2215 # Separate this cycle from the rest of the graph 

2216 for n1 in synteny_linear.nodes(): 

2217 if n1 in cycle: continue 

2218 neighbors = set(synteny_linear.neighbors(n1)) 

2219 for n2 in cycle_set.intersection(neighbors): 

2220 logging.debug(f"Isolating cycle by removing edge: {n1} <-> {n2}") 

2221 synteny_linear.remove_edge(n1, n2) 

2222 

2223 # Break the final cycle between the 'first' and 'last' nodes 

2224 first, last = cycle[0], cycle[-1] 

2225 logging.debug(f"Breaking cycle by removing edge: {last} -> {first}") 

2226 synteny_linear.remove_edge(last, first) 

2227 

2228 except NetworkXNoCycle: 

2229 cycles = False 

2230 

2231 # -------------------------------------------------------------------------  

2232 # Identify synteny blocks 

2233 

2234 # We need to use the function connected_components, but that only  

2235 # works on undirected graphs 

2236 synteny_linear_undirected = synteny_linear.to_undirected() 

2237 

2238 # Get synteny blocks, sorted largest to smallest 

2239 logging.info(f"Identifying synteny blocks.") 

2240 synteny_blocks = sorted([ 

2241 synteny_linear_undirected.subgraph(c).copy() 

2242 for c in networkx.connected_components(synteny_linear_undirected) 

2243 ], key=len, reverse=True) 

2244 

2245 # Now we have enough information to finalize the order 

2246 summarized_order = OrderedDict() 

2247 

2248 for i_b, block in enumerate(tqdm(synteny_blocks)): 

2249 i_b += 1 

2250 clusters = list(block) 

2251 

2252 if len(clusters) > 1: 

2253 terminals = [n for n in block.nodes if len(list(block.neighbors(n))) == 1] 

2254 # Legacy error from troubleshooting, unclear if still relevant  

2255 if len(terminals) != 2: 

2256 # Check if it's a cycle we somehow didn't handle 

2257 try: 

2258 cycle = networkx.find_cycle(block) 

2259 cycle = True 

2260 except NetworkXNoCycle: 

2261 cycle = False 

2262 msg = f"Synteny block {i_b} has an unhandled error, cycle={cycle}, terminals={len(terminals)}: {terminals}" 

2263 logging.error(msg) 

2264 print(networkx.write_network_text(block)) 

2265 raise Exception(msg) 

2266 

2267 # Figure out which terminal is the 5' end 

2268 first_upstream = summarized[terminals[0]]["upstream"] 

2269 last_upstream = summarized[terminals[-1]]["upstream"] 

2270 

2271 if first_upstream == "TERMINAL": 

2272 first, last = terminals[0], terminals[1] 

2273 elif last_upstream == "TERMINAL": 

2274 first, last = terminals[1], terminals[0] 

2275 else: 

2276 # If it's fully ambiguous, we'll sort the terminals for reproducibility 

2277 terminals = sorted(terminals) 

2278 first, last = terminals[0], terminals[1] 

2279 

2280 # Manually walk through the graph, this is not ideal 

2281 # but networkx's bfs/dfs was giving me odd values in testing 

2282 clusters, neighbors = [first], [] 

2283 curr_node, next_node = first, None 

2284 while next_node != last: 

2285 neighbors = [n for n in block.neighbors(curr_node) if n not in clusters] 

2286 # Legacy error from troubleshooting, unclear if still relevant 

2287 if len(neighbors) != 1: 

2288 msg = f"Synteny error, unhandled multifurcation in {curr_node}: {neighbors}" 

2289 logging.error(msg) 

2290 raise Exception(msg) 

2291 next_node = neighbors[0] 

2292 clusters.append(next_node) 

2293 curr_node = next_node 

2294 

2295 for i_c, cluster in enumerate(clusters): 

2296 summarized_order[cluster] = summarized[cluster] 

2297 summarized_order[cluster]["synteny"] = str(i_b) 

2298 summarized_order[cluster]["synteny_pos"] = str(i_c + 1) 

2299 

2300 upstream = summarized_order[cluster]["upstream"] 

2301 downstream = summarized_order[cluster]["downstream"] 

2302 

2303 # Use the synteny block to finalize upstream/downstream 

2304 upstream = summarized_order[cluster]["upstream"] 

2305 downstream = summarized_order[cluster]["downstream"] 

2306 

2307 # We'll save a copy of how it was before as the alt 

2308 upstream_orig, downstream_orig = copy.deepcopy(upstream), copy.deepcopy(downstream) 

2309 

2310 # 'TERMINAL' will now refer to the ends of the synteny block 

2311 if i_c > 0: 

2312 upstream = clusters[i_c - 1] 

2313 else: 

2314 upstream = "TERMINAL" 

2315 

2316 if i_c < (len(clusters) - 1): 

2317 downstream = clusters[i_c + 1] 

2318 else: 

2319 downstream = "TERMINAL" 

2320 

2321 if upstream != upstream_orig: 

2322 summarized_order[cluster]["upstream_alt"] = upstream_orig 

2323 if downstream != downstream_orig: 

2324 summarized_order[cluster]["downstream_alt"] = downstream_orig 

2325 

2326 summarized_order[cluster]["upstream"] = upstream 

2327 summarized_order[cluster]["downstream"] = downstream 

2328 

2329 summarized = summarized_order 

2330 

2331 # ------------------------------------------------------------------------- 

2332 # Create directed 

2333 

2334 logging.info(f"Converting synteny blocks to directed graph.") 

2335 

2336 # Now we need to go back to a directed graph 

2337 synteny_linear_directed = networkx.DiGraph() 

2338 synteny_seen = set() 

2339 

2340 for cluster,cluster_data in summarized.items(): 

2341 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"] 

2342 if cluster not in synteny_seen: 

2343 synteny_seen.add(cluster) 

2344 synteny_linear_directed.add_node(cluster) 

2345 

2346 if upstream != "TERMINAL": 

2347 if upstream not in synteny_seen: 

2348 synteny_seen.add(upstream) 

2349 synteny_linear_directed.add_edge(upstream, cluster) 

2350 

2351 if downstream != "TERMINAL": 

2352 if downstream not in synteny_seen: 

2353 synteny_seen.add(downstream) 

2354 synteny_linear_directed.add_edge(cluster, downstream) 

2355 

2356 # ------------------------------------------------------------------------- 

2357 # Identifiers: Part 2 

2358 # ------------------------------------------------------------------------- 

2359 

2360 # Give unannotated clusters identifiers based on their upstream/downstream 

2361 # loci in the synteny graph. 

2362 

2363 # Note: the synteny reconstruction ensures that we're not going to have 

2364 # any duplicate cluster IDs for the unannotated regions 

2365 # because any loops/multifurcations have been removed. 

2366 
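    # Editor's sketch (hypothetical clusters): an unannotated cluster flanked by
    # 'geneA' upstream and 'geneB' downstream is renamed 'geneA__geneB'. A
    # cluster with no known neighbours (both TERMINAL), or whose neighbours
    # already carry the '__' notation, keeps its original Cluster_N identifier.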

2367 logging.info(f"Assigning identifiers to unannotated clusters.") 

2368 

2369 # ------------------------------------------------------------------------- 

2370 # Pass #1: Give identifiers to unannotated clusters based on upstream/downstream 

2371 

2372 for cluster,cluster_data in tqdm(summarized.items()): 

2373 # Skip this cluster if it already has a new identifier 

2374 # Example, annotated clusters based on gene/product 

2375 if cluster in identifiers and identifiers[cluster] != cluster: 

2376 continue 

2377 # Skip this cluster if it has gene/product info, just a safety fallback 

2378 if cluster_data["gene"] != "" or cluster_data["product"] != "": 

2379 continue 

2380 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"] 

2381 # Option 1. No known neighbors 

2382 if upstream == "TERMINAL" and downstream == "TERMINAL": 

2383 identifier = cluster 

2384 # Option 2. Upstream/downstream already has the notation 

2385 elif "__" in upstream or "__" in downstream: 

2386 identifier = cluster 

2387 # Option 3. At least one side is known 

2388 else: 

2389 identifier = f"{upstream}__{downstream}" 

2390 

2391 identifiers[cluster] = identifier 

2392 summarized[cluster]["cluster"] = identifier 

2393 

2394 # ------------------------------------------------------------------------- 

2395 # Pass #2: Finalize upstream/downstream 

2396 

2397 # Now that we know the identifiers, we need to update the  

2398 # following fields: cluster, upstream, and downstream 

2399 

2400 logging.info(f"Finalizing upstream/downstream identifiers.") 

2401 

2402 for cluster, cluster_data in tqdm(summarized.items()): 

2403 upstream, downstream = cluster_data["upstream"], cluster_data["downstream"] 

2404 if upstream in identifiers: 

2405 summarized[cluster]["upstream"] = upstream 

2406 if downstream in identifiers: 

2407 summarized[cluster]["downstream"] = downstream 

2408 

2409 # Update the keys in the graph 

2410 summarized = OrderedDict({v["cluster"]:v for k,v in summarized.items()}) 

2411 

2412 # ------------------------------------------------------------------------- 

2413 # Update synteny graph with new identifiers 

2414 # ------------------------------------------------------------------------- 

2415 

2416 logging.info(f"Updating cluster identifiers in the synteny graphs.") 

2417 

2418 # We will update both the original "full" graph and our new 

2419 # "linear" graph. 

2420 

2421 # ------------------------------------------------------------------------- 

2422 # Full Graph 

2423 

2424 synteny_full = networkx.relabel_nodes(synteny_full, mapping=identifiers) 

2425 edges = list(synteny_full.out_edges()) 

2426 for c1, c2 in edges: 

2427 synteny_full.remove_edge(c1, c2) 

2428 c1 = identifiers[c1] if c1 in identifiers else c1 

2429 c2 = identifiers[c2] if c2 in identifiers else c2 

2430 synteny_full.add_edge(c1, c2) 

2431 

2432 synteny_full_path = os.path.join(output_dir, f"{prefix}synteny.full.graphml") 

2433 logging.info(f"Writing full synteny GraphML: {synteny_full_path}") 

2434 networkx.write_graphml(synteny_full, synteny_full_path) 

2435 

2436 gfa_path = os.path.join(output_dir, f"{prefix}synteny.full.gfa") 

2437 logging.info(f"Writing full synteny GFA: {gfa_path}") 

2438 

2439 with open(gfa_path, 'w') as outfile: 

2440 outfile.write("H\tVN:Z:1.0\n") 

2441 for cluster in summarized: 

2442 outfile.write(f"S\t{cluster}\t*\tLN:i:1\n") 

2443 for c1, c2 in synteny_full.out_edges(): 

2444 c1_strand, c2_strand = summarized[c1]["strand"], summarized[c2]["strand"] 

2445 outfile.write(f"L\t{c1}\t{c1_strand}\t{c2}\t{c2_strand}\t0M\n") 

2446 
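    # Editor's sketch of the GFA written above (hypothetical cluster names):
    #
    #   H  VN:Z:1.0
    #   S  geneA  *  LN:i:1
    #   S  geneB  *  LN:i:1
    #   L  geneA  +  geneB  +  0M
    #
    # Each cluster becomes an 'S' segment (sequence omitted, placeholder length
    # of 1) and each synteny edge becomes an 'L' link using the clusters'
    # consensus strands; fields are tab-separated in the real file.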

2447 # ------------------------------------------------------------------------- 

2448 # Linearized Graph 

2449 

2450 synteny_linear_directed = networkx.relabel_nodes(synteny_linear_directed, mapping=identifiers) 

2451 edges = list(synteny_linear_directed.out_edges()) 

2452 

2453 for c1, c2 in edges: 

2454 synteny_linear_directed.remove_edge(c1, c2) 

2455 c1 = identifiers[c1] if c1 in identifiers else c1 

2456 c2 = identifiers[c2] if c2 in identifiers else c2 

2457 synteny_linear_directed.add_edge(c1, c2) 

2458 

2459 synteny_linear_path = os.path.join(output_dir, f"{prefix}synteny.linear.graphml") 

2460 logging.info(f"Writing linear synteny GraphML: {synteny_linear_path}") 

2461 networkx.write_graphml(synteny_linear_directed, synteny_linear_path) 

2462 

2463 gfa_path = os.path.join(output_dir, f"{prefix}synteny.linear.gfa") 

2464 logging.info(f"Writing linear synteny GFA: {gfa_path}") 

2465 

2466 with open(gfa_path, 'w') as outfile: 

2467 outfile.write("H\tVN:Z:1.0\n") 

2468 for cluster in summarized: 

2469 outfile.write(f"S\t{cluster}\t*\tLN:i:1\n") 

2470 for c1, c2 in synteny_linear_directed.out_edges(): 

2471 c1_strand, c2_strand = summarized[c1]["strand"], summarized[c2]["strand"] 

2472 outfile.write(f"L\t{c1}\t{c1_strand}\t{c2}\t{c2_strand}\t0M\n") 

2473 

2474 # ------------------------------------------------------------------------- 

2475 # Write Output tsv 

2476 

2477 tsv_path = os.path.join(output_dir, f"{prefix}clusters.tsv") 

2478 logging.info(f"Writing summarized clusters tsv: {tsv_path}") 

2479 

2480 with open(tsv_path, 'w') as outfile: 

2481 header = None 

2482 for i, cluster in enumerate(tqdm(summarized)): 

2483 cluster_data = summarized[cluster] 

2484 if not header: 

2485 header = [k for k in cluster_data if k != "sequences" and k != "product_clean"] + all_samples 

2486 outfile.write("\t".join(header) + "\n") 

2487 row = [str(v) for k,v in cluster_data.items() if k != "sequences" and k != "product_clean"] 

2488 # Add info about which sequences map to each sample in the cluster 

2489 sample_to_seq_id = OrderedDict({sample:[] for sample in all_samples}) 

2490 for seq_id,seq_data in cluster_data["sequences"].items(): 

2491 sample = seq_data["sample"] 

2492 sample_to_seq_id[sample].append(seq_id) 

2493 for sample in all_samples: 

2494 row += [",".join(sample_to_seq_id[sample])] 

2495 

2496 outfile.write("\t".join(row) + "\n") 

2497 

2498 # ------------------------------------------------------------------------- 

2499 # Write table (for phandango) 

2500 

2501 phandango_path = os.path.join(output_dir, f"{prefix}phandango.csv") 

2502 logging.info(f"Writing table for phandango: {phandango_path}") 

2503 

2504 logging.info(f"Sorting synteny blocks according to number of samples: {phandango_path}") 

2505 syntenies = OrderedDict() 

2506 for cluster,c_data in summarized.items(): 

2507 synteny = c_data["synteny"] 

2508 if synteny not in syntenies: 

2509 syntenies[synteny] = {"max_samples": 0, "clusters": []} 

2510 num_samples = c_data["num_samples"] 

2511 syntenies[synteny]["max_samples"] = max(num_samples, syntenies[synteny]["max_samples"]) 

2512 syntenies[synteny]["clusters"].append(cluster) 

2513 

2514 syntenies = OrderedDict( 

2515 sorted( 

2516 syntenies.items(), 

2517 key=lambda item: item[1]["max_samples"], reverse=True 

2518 ) 

2519 ) 

2520 

2521 # This is the roary gene_presence_absence.csv format 

2522 with open(phandango_path, 'w') as outfile: 

2523 header = [ 

2524 "Gene","Non-unique Gene name","Annotation","No. isolates", 

2525 "No. sequences","Avg sequences per isolate","Genome Fragment", 

2526 "Order within Fragment","Accessory Fragment", 

2527 "Accessory Order with Fragment","QC" 

2528 ] + all_samples 

2529 outfile.write(",".join(header) + "\n") 

2530 for synteny,s_data in tqdm(syntenies.items()): 

2531 for cluster in s_data["clusters"]: 

2532 c_data = summarized[cluster] 

2533 data = OrderedDict({k:"" for k in header}) 

2534 # This is deliberately reversed 

2535 data["Gene"] = c_data["cluster"] 

2536 data["Non-unique Gene name"] = c_data["cluster_id"] 

2537 # We're going to use the cleaned product, because 

2538 # we need at least some bad characters to be removed  

2539 # (like commas) 

2540 data["Annotation"] = c_data["product_clean"] 

2541 data["No. isolates"] = c_data["num_samples"] 

2542 data["No. sequences"] = cluster_data["num_sequences"] 

2543 data["Avg sequences per isolate"] = c_data["mean_sequences_per_sample"] 

2544 data["Genome Fragment"] = c_data["synteny"] 

2545 data["Order within Fragment"] = c_data["synteny_pos"] 

2546 # If a sample has a sequence then it will be given a "1" 

2547 # If not, it is recorded as the empty string "" 

2548 # This is based on the file-minimising script from phandango: 

2549 # https://github.com/jameshadfield/phandango/blob/master/scripts/minimiseROARY.py  
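# Illustrative row ending (hypothetical, three samples): a cluster present in
# sample1 and sample3 but absent from sample2 ends its row with: ...,1,,1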

2550 for s_data in c_data["sequences"].values(): 

2551 sample = s_data["sample"] 

2552 data[sample] = "1" 

2553 line = ",".join([str(v) for v in data.values()]) 

2554 outfile.write(line + "\n") 

2555 

2556 return tsv_path 

2557 

2558 

2559# Parallel 

2560def run_mafft(kwargs: dict): 

2561 """A wrapper function to pass multi-threaded pool args to mafft.""" 

2562 cmd, output, quiet, display_cmd = [kwargs[k] for k in ["cmd", "output", "quiet", "display_cmd"]] 

2563 run_cmd(cmd=cmd, output=output, quiet=quiet, display_cmd=display_cmd) 

2564 

2565def align( 

2566 clusters: str, 

2567 regions: str, 

2568 outdir: str = ".", 

2569 prefix: str = None, 

2570 exclude_singletons: bool = False, 

2571 threads: int = 1, 

2572 args: str = ALIGN_ARGS, 

2573 ): 

2574 """ 

2575 Align clusters using mafft and create a pangenome alignment. 

2576 

2577 Takes as input the summarized clusters from summarize and sequence regions from collect. 

2578 Outputs multiple sequence alignments per cluster as well as a pangenome alignment of  

2579 concatenated clusters. 

2580 

2581 >>> align(clusters="summarize.clusters.tsv", regions="sequences.tsv") 

2582 >>> align(clusters="summarize.clusters.tsv", regions="sequences.tsv", exclude_singletons=True, args="--localpair") 

2583 

2584 :param clusters: TSV file of clusters from summarize. 

2585 :param regions: TSV file of sequence regions from collect. 

2586 :param outdir: Output directory. 

2587 :param prefix: Prefix for output files. 

2588 :param exclude_singletons: True if clusters found in only one sample should be excluded. 

2589 :param threads: Number of cpu threads to parallelize mafft across. 

2590 :param args: Additional arguments for MAFFT. 

2591 """ 

2592 

2593 from multiprocessing import get_context 

2594 

2595 clusters_path, sequences_path, output_dir = clusters, regions, outdir 

2596 

2597 # Check output directory 

2598 check_output_dir(output_dir) 

2599 

2600 args = args if args != None else "" 

2601 prefix = f"{prefix}." if prefix != None else "" 

2602 threads = int(threads) 

2603 

2604 # ------------------------------------------------------------------------- 

2605 # Read Sequence Regions 

2606 

2607 all_samples = [] 

2608 sequences = {} 

2609 

2610 logging.info(f"Reading sequence regions: {sequences_path}") 

2611 with open(sequences_path) as infile: 

2612 header = [line.strip() for line in infile.readline().split("\t")] 

2613 for line in infile: 

2614 row = [l.strip() for l in line.split("\t")] 

2615 data = {k:v for k,v in zip(header, row)} 

2616 sample, sequence_id = data["sample"], data["sequence_id"] 

2617 if sample not in all_samples: 

2618 all_samples.append(sample) 

2619 sequences[sequence_id] = data 

2620 

2621 # ------------------------------------------------------------------------- 

2622 # Read Summarized Clusters 

2623 

2624 logging.info(f"Reading summarized clusters: {clusters_path}") 

2625 clusters = OrderedDict() 

2626 all_samples = [] 

2627 with open(clusters_path) as infile: 

2628 header = [line.strip() for line in infile.readline().split("\t")] 

2629 # sample columns begin after dbxref_alt 

2630 all_samples = header[header.index("dbxref_alt")+1:] 

2631 lines = infile.readlines() 

2632 for line in tqdm(lines): 

2633 row = [l.strip() for l in line.split("\t")] 

2634 data = {k:v for k,v in zip(header, row)} 

2635 cluster = data["cluster"] 

2636 clusters[cluster] = data 

2637 clusters[cluster]["sequences"] = OrderedDict() 

2638 for sample in all_samples: 

2639 sequence_ids = data[sample].split(",") if data[sample] != "" else [] 

2640 # Add the sequence IDs; we will associate them with the actual 

2641 # sequences from the collect TSV 

2642 for sequence_id in sequence_ids: 

2643 clusters[cluster]["sequences"][sequence_id] = { 

2644 "sequence": sequences[sequence_id]["sequence"], 

2645 "sample": sample 

2646 } 

2647 

2648 # ------------------------------------------------------------------------- 

2649 # Write Sequences 

2650 

2651 representative_output = os.path.join(output_dir, prefix + "representative") 

2652 sequences_output = os.path.join(output_dir, prefix + "sequences") 

2653 alignments_output = os.path.join(output_dir, prefix + "alignments") 

2654 consensus_output = os.path.join(output_dir, prefix + "consensus") 

2655 

2656 check_output_dir(representative_output) 

2657 check_output_dir(sequences_output) 

2658 check_output_dir(alignments_output) 

2659 check_output_dir(consensus_output) 

2660 

2661 # ------------------------------------------------------------------------- 

2662 # Write Representative sequences for each cluster to file 

2663 

2664 skip_align = set() 

2665 clusters_exclude_singletons = OrderedDict() 

2666 

2667 logging.info(f"Writing representative sequences: {representative_output}") 

2668 for cluster,cluster_data in tqdm(clusters.items()): 

2669 rep_seq_id = cluster_data["representative"] 

2670 rep_seq = cluster_data["sequences"][rep_seq_id]["sequence"] 

2671 rep_path = os.path.join(representative_output, f"{cluster}.fasta") 

2672 samples = list(set([s["sample"] for s in cluster_data["sequences"].values()])) 

2673 

2674 # If only found in 1 sample, and the user requested to exclude these, skip 

2675 if len(samples) == 1 and exclude_singletons: 

2676 logging.debug(f"Skipping singleton cluster: {cluster}") 

2677 continue 

2678 

2679 # If this cluster only has one sequence, write it to the final output 

2680 if len(cluster_data["sequences"]) == 1: 

2681 skip_align.add(cluster) 

2682 file_path = os.path.join(alignments_output, f"{cluster}.aln") 

2683 with open(file_path, 'w') as outfile: 

2684 outfile.write(f">{rep_seq_id}\n{rep_seq}\n") 

2685 # Otherwise, save rep seq to separate folder 

2686 # We might need it if the user has requested the --add* mafft args 

2687 else: 

2688 with open(rep_path, 'w') as outfile: 

2689 outfile.write(f">{rep_seq_id}\n{rep_seq}\n") 

2690 

2691 clusters_exclude_singletons[cluster] = cluster_data 

2692 

2693 clusters = clusters_exclude_singletons 

2694 

2695 # ------------------------------------------------------------------------- 

2696 # Write DNA sequences for each cluster to file 

2697 

2698 # A queue of commands to submit to mafft in parallel  

2699 mafft_queue = [] 

2700 

2701 logging.info(f"Writing cluster sequences: {sequences_output}") 

2702 for i,cluster in enumerate(tqdm(clusters)): 

2703 # Skip singleton clusters, we already handled them in previous block 

2704 if cluster in skip_align: continue 

2705 cluster_data = clusters[cluster] 

2706 rep_seq_id = cluster_data["representative"] 

2707 representative_path = os.path.join(representative_output, f"{cluster}.fasta") 

2708 sequences_path = os.path.join(sequences_output, f"{cluster}.fasta") 

2709 with open(sequences_path, "w") as outfile: 

2710 for sequence_id,seq_data in cluster_data["sequences"].items(): 

2711 # Skip representative if we're using addfragments 

2712 if "--add" in args and sequence_id == rep_seq_id: 

2713 continue 

2714 sequence = seq_data["sequence"] 

2715 line = f">{sequence_id}\n{sequence}\n" 

2716 outfile.write(line) 

2717 alignment_path = os.path.join(alignments_output, f"{cluster}.aln") 

2718 cmd = f"mafft --thread 1 {args} {sequences_path}" 

2719 # If we're using addfragments, we're aligning against representative/reference 

2720 if "--add" in args: 

2721 cmd += f" {representative_path}" 

2722 mafft_queue.append({"cmd": cmd, "output": alignment_path, "quiet": True, "display_cmd": False}) 
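# Illustrative queued command, assuming the default ALIGN_ARGS, outdir="." and a
# hypothetical cluster named "Cluster_1":
#   mafft --thread 1 --adjustdirection --localpair --maxiterate 1000 \
#     --addfragments sequences/Cluster_1.fasta representative/Cluster_1.fasta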

2723 

2724 # Display first command 

2725 if len(mafft_queue) > 0: 

2726 logging.info(f"Command to run in parallel: {mafft_queue[0]['cmd']}") 

2727 

2728 # ------------------------------------------------------------------------- 

2729 # Align DNA sequences with MAFFT 

2730 

2731 # Parallel: MAFFT is very CPU-bound, so parallel is appropriate 

2732 logging.info(f"Aligning cluster sequences in parallel with {threads} threads: {alignments_output}") 

2733 with get_context('fork').Pool(threads) as p: 

2734 with tqdm(total=len(mafft_queue), unit="cluster") as bar: 

2735 for _ in p.imap_unordered(run_mafft, mafft_queue): 

2736 bar.update() 

2737 

2738 # ------------------------------------------------------------------------- 

2739 # Unwrap and dedup cluster alignments 

2740 

2741 # Duplicates occur due to multi-copy/fragments. We'll try to dedup the  

2742 # sequence by reconstructing consensus bases where possible. 
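# Illustrative sketch (hypothetical sample with two copies of a cluster):
#   copy 1:    A C G - T A -
#   copy 2:    A C G A - G -
#   consensus: A C G A T N -
# A real base outranks a gap, disagreement between bases becomes "N", and
# all-gap columns stay "-".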

2743 

2744 logging.info(f"Unwrapping and dedupping alignments: {consensus_output}") 

2745 consensus_alignments = OrderedDict() 

2746 

2747 for cluster,cluster_data in tqdm(clusters.items()): 

2748 original_path = os.path.join(alignments_output, cluster + ".aln") 

2749 tmp_path = original_path + ".tmp" 

2750 consensus_path = os.path.join(consensus_output, cluster + ".aln") 

2751 

2752 # Unwrap sequences and convert to uppercase 

2753 alignment = {} 

2754 with open(original_path, 'r') as infile: 

2755 with open(tmp_path, 'w') as outfile: 

2756 records = infile.read().split(">")[1:] 

2757 for record in records: 

2758 record_split = record.split("\n") 

2759 sequence_id = record_split[0] 

2760 # Check for indicator that it was reverse complemented 

2761 if sequence_id not in cluster_data["sequences"] and sequence_id.startswith("_R_"): 

2762 sequence_id = sequence_id[3:] 

2763 

2764 sample = cluster_data["sequences"][sequence_id]["sample"] 

2765 sequence = "".join(record_split[1:]).replace("\n", "").upper() 

2766 if sample not in alignment: 

2767 alignment[sample] = [] 

2768 alignment[sample].append(sequence) 

2769 

2770 # Write the sequence to the output file, using the original header 

2771 outfile.write(f">{sequence_id}\n{sequence}\n") 

2772 

2773 # Replace the original file, now that it is unwrapped and uppercased 

2774 # We can use this in a different script for structural variant detection 

2775 shutil.move(tmp_path, original_path) 

2776 

2777 if len(alignment) == 0: 

2778 logging.info(f"WARNING: No sequences written for cluster: {cluster}") 

2779 continue 

2780 

2781 # Create consensus sequence from fragments and multi-copies 

2782 duplicate_samples = [sample for sample,seqs in alignment.items() if len(seqs) > 1] 

2783 # Create consensus alignment, first with the non-duplicated samples 

2784 alignment_consensus = {sample:seqs[0] for sample,seqs in alignment.items() if len(seqs) == 1} 

2785 

2786 for sample in duplicate_samples: 

2787 seqs = alignment[sample] 

2788 length = len(seqs[0]) 

2789 consensus = [] 

2790 

2791 # Reconstruct Consensus sequence 

2792 for i in range(0, length): 

2793 nuc_raw = set([s[i] for s in seqs]) 

2794 nuc = list(set([n for n in nuc_raw if n in NUCLEOTIDES] )) 

2795 if len(nuc) == 1: nuc_consensus = nuc[0] 

2796 elif nuc_raw == set("-"): nuc_consensus = "-" 

2797 else: nuc_consensus = "N" 

2798 consensus.append(nuc_consensus) 

2799 

2800 consensus_sequence = "".join(consensus) 

2801 alignment_consensus[sample] = consensus_sequence 

2802 

2803 consensus_alignments[cluster] = alignment_consensus 

2804 

2805 # Write unwrapped, dedupped, defragged, consensus alignment 

2806 with open(consensus_path, 'w') as outfile: 

2807 for (sample, sequence) in alignment_consensus.items(): 

2808 outfile.write(f">{sample}\n{sequence}\n") 

2809 

2810 # ------------------------------------------------------------------------- 

2811 # Concatenate into pangenome alignment 

2812 

2813 logging.info(f"Creating pangenome alignment.") 

2814 

2815 pangenome = { 

2816 "bed" : {}, 

2817 "alignment" : {s: [] for s in all_samples}, 

2818 } 

2819 

2820 curr_pos = 0 

2821 

2822 for cluster,alignment in tqdm(consensus_alignments.items()): 

2823 # identify samples missing from cluster 

2824 observed_samples = list(alignment.keys()) 

2825 missing_samples = [s for s in all_samples if s not in alignment] 

2826 

2827 # write gaps for missing samples 

2828 seq_len = len(alignment[observed_samples[0]]) 

2829 for sample in missing_samples: 

2830 alignment[sample] = "-" * seq_len 

2831 

2832 # concatenate cluster sequence to phylo alignment 

2833 for sample, seq in alignment.items(): 

2834 pangenome["alignment"][sample].append(seq) 

2835 

2836 # update bed coordinates 

2837 prev_pos = curr_pos 

2838 curr_pos = curr_pos + seq_len 

2839 

2840 pangenome["bed"][prev_pos] = { 

2841 "start" : prev_pos, 

2842 "end" : curr_pos, 

2843 "cluster" : cluster, 

2844 "synteny": clusters[cluster]["synteny"] 

2845 } 

2846 

2847 # Final concatenation of alignment 

2848 logging.info("Performing final concatenation.") 

2849 for sample in tqdm(all_samples): 

2850 pangenome["alignment"][sample] = "".join( pangenome["alignment"][sample]) 

2851 

2852 # ------------------------------------------------------------------------- 

2853 # Write pangenome Consensus Sequence 

2854 

2855 consensus_file_path = os.path.join(output_dir, prefix + "pangenome.consensus.fasta") 

2856 logging.info(f"Writing pangenome consensus: {consensus_file_path}") 

2857 with open(consensus_file_path, 'w') as out_file: 

2858 out_file.write(">consensus\n") 

2859 # get total length from first sample in alignment 

2860 first_sample = list(pangenome["alignment"].keys())[0] 

2861 consensus_len = len(pangenome["alignment"][first_sample]) 

2862 consensus = [] 

2863 logging.info(f"Consensus Length: {consensus_len}") 

2864 for cluster_start,bed_info in tqdm(pangenome["bed"].items()): 

2865 cluster_end, cluster = bed_info["end"], bed_info["cluster"] 

2866 for i in range(cluster_start, cluster_end): 

2867 nuc = [pangenome["alignment"][s][i] for s in all_samples] 

2868 nuc_non_ambig = [n for n in nuc if n in NUCLEOTIDES] 

2869 # All samples excluded due to ambiguity 

2870 if len(set(nuc_non_ambig)) == 0: 

2871 consensus_nuc = "N" 

2872 # Invariant position, all the same 

2873 elif len(set(nuc_non_ambig)) == 1: 

2874 consensus_nuc = nuc[0] 

2875 # Variant position, get consensus of nonambiguous nucleotides 

2876 else: 

2877 # choose whichever nucleotide has the highest count as the consensus 

2878 counts = {n:nuc.count(n) for n in NUCLEOTIDES} 

2879 consensus_nuc = max(counts, key=counts.get) 

2880 consensus.append(consensus_nuc) 

2881 # write the consensus sequence and trailing newline 

2882 out_file.write("".join(consensus) + "\n") 

2883 

2884 # ------------------------------------------------------------------------- 

2885 # Write Bed File 

2886 

2887 bed_file_path = os.path.join(output_dir, prefix + "pangenome.bed") 

2888 logging.info(f"Writing bed file: {bed_file_path}") 
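# Each row is tab-separated with a 0-based start coordinate, e.g. (illustrative,
# hypothetical cluster and synteny values):
#   pangenome    0    1470    cluster=Cluster_1;synteny=1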

2889 with open(bed_file_path, "w") as bed_file: 

2890 for start,info in tqdm(pangenome["bed"].items()): 

2891 cluster, synteny = info["cluster"], info["synteny"] 

2892 name = f"cluster={cluster};synteny={synteny}" 

2893 row = ["pangenome", start, info["end"], name] 

2894 line = "\t".join([str(val) for val in row]) 

2895 bed_file.write(line + "\n") 

2896 

2897 # ------------------------------------------------------------------------- 

2898 # Write Alignment 

2899 

2900 aln_file_path = os.path.join(output_dir, prefix + "pangenome.aln") 

2901 logging.info(f"Writing alignment file: {aln_file_path}") 

2902 with open(aln_file_path, "w") as aln_file: 

2903 for sample in tqdm(all_samples): 

2904 seq = pangenome["alignment"][sample].upper() 

2905 aln_file.write(">" + sample + "\n" + seq + "\n") 

2906 

2907 # ------------------------------------------------------------------------- 

2908 # Write GFF 

2909 

2910 gff_file_path = os.path.join(output_dir, prefix + "pangenome.gff3") 

2911 logging.info(f"Writing gff file: {gff_file_path}") 

2912 with open(gff_file_path, "w") as outfile: 

2913 outfile.write("##gff-version 3\n") 

2914 outfile.write(f"##sequence-region pangenome 1 {consensus_len}\n") 

2915 row = ["pangenome", ".", "region", 1, consensus_len, ".", "+", ".", "ID=pangenome;Name=pangenome"] 

2916 line = "\t".join([str(v) for v in row]) 

2917 outfile.write(line + "\n") 

2918 

2919 for start in pangenome["bed"]: 

2920 start, end, cluster, synteny = pangenome["bed"][start].values() 

2921 cluster_data = clusters[cluster] 

2922 representative, feature = cluster_data["representative"], cluster_data["feature"] 

2923 strand = sequences[representative]["strand"] 

2924 attributes = OrderedDict({"ID": cluster, "locus_tag": cluster}) 

2925 for col,k in zip(["gene", "name", "product", "dbxref"], ["gene", "Name", "product", "Dbxref"]): 

2926 if cluster_data[col] != "": 

2927 attributes[k] = cluster_data[col] 

2928 attributes["synteny"] = synteny 

2929 attributes = ";".join([f"{k}={v}" for k,v in attributes.items()]) 

2930 row = ["pangenome", ".", feature, start, end, ".", strand, ".", attributes ] 

2931 line = "\t".join([str(v) for v in row]) 

2932 outfile.write(line + "\n") 

2933 

2934 

2935def get_ranges(numbers): 

2936 """ 

2937 Author: user97370, bossylobster 

2938 Source: https://stackoverflow.com/a/4629241 
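 
    Yields (start, end) tuples for runs of consecutive integers. Illustrative
    example:
 
    >>> list(get_ranges([1, 2, 3, 7, 8, 12]))
    [(1, 3), (7, 8), (12, 12)]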

2939 """ 

2940 for _a, b in itertools.groupby(enumerate(numbers), lambda pair: pair[1] - pair[0]): 

2941 b = list(b) 

2942 yield b[0][1], b[-1][1] 

2943 

2944 

2945def collapse_ranges(ranges, min_len: int, min_gap: int): 

2946 """Collapse a list of ranges into overlaps.""" 

2947 collapse = [] 

2948 curr_range = None 

2949 # Collapse based on min gap 

2950 for i, coord in enumerate(ranges): 

2951 start, stop = coord 

2952 if curr_range: 

2953 gap_size = start - curr_range[1] - 1 

2954 # Collapse tiny gap into previous fragment 

2955 if gap_size < min_gap: 

2956 curr_range = (curr_range[0], stop) 

2957 # Otherwise, start new range 

2958 else: 

2959 collapse.append(curr_range) 

2960 curr_range = (start, stop) 

2961 else: 

2962 curr_range = (start, stop) 

2963 # Last one 

2964 if curr_range not in collapse and i == len(ranges) - 1: 

2965 collapse.append(curr_range) 

2966 # Collapse based on min length size 

2967 collapse = [(start, stop) for start,stop in collapse if (stop - start + 1) >= min_len] 

2968 return collapse 

2969 

2970 

2971def structural( 

2972 clusters: str, 

2973 alignments: str, 

2974 outdir: str = ".", 

2975 prefix: str = None, 

2976 min_len: int = 10, 

2977 min_indel_len: int = 3, 

2978 args: str = None, 

2979 ): 

2980 """ 

2981 Extract structural variants from cluster alignments. 

2982 

2983 Takes as input the summarized clusters TSV and their individual alignments. 

2984 Outputs an Rtab file of structural variants. 

2985 

2986 >>> structural(clusters="summarize.clusters.tsv", alignments="alignments") 

2987 >>> structural(clusters="summarize.tsv", alignments="alignments", min_len=100, min_indel_len=10) 

2988 

2989 :param clusters: TSV file of clusters from summarize. 

2990 :param alignments: Directory of alignments from align. 

2991 :param outdir: Output directory. 

2992 :param prefix: Prefix for output files. 

2993 :param min_len: Minimum length of structural variants to extract. 

2994 :param min_indel_len: Minimum length of gaps that should separate variants. 

2995 :param args: Str of additional arguments [not implemented] 

2996 

2997 :return: Ordered Dictionary of structural variants. 

2998 """ 

2999 

3000 from Bio import SeqIO 

3001 

3002 clusters_path, alignments_dir, output_dir = clusters, alignments, outdir 

3003 args = args if args != None else "" 

3004 prefix = f"{prefix}." if prefix != None else "" 

3005 

3006 # Check output directory 

3007 check_output_dir(output_dir) 

3008 

3009 variants = OrderedDict() 

3010 

3011 # ------------------------------------------------------------------------- 

3012 # Read Clusters 

3013 

3014 logging.info(f"Reading summarized clusters: {clusters_path}") 

3015 all_samples = [] 

3016 

3017 with open(clusters_path) as infile: 

3018 header = [line.strip() for line in infile.readline().split("\t")] 

3019 # sample columns begin after dbxref_alt 

3020 all_samples = header[header.index("dbxref_alt")+1:] 

3021 lines = infile.readlines() 

3022 for line in tqdm(lines): 

3023 row = [l.strip() for l in line.split("\t")] 

3024 data = {k:v for k,v in zip(header, row)} 

3025 cluster = data["cluster"] 

3026 

3027 # Get mapping of sequence IDs to samples 

3028 seq_to_sample = {} 

3029 for sample in all_samples: 

3030 sequence_ids = data[sample].split(",") if data[sample] != "" else [] 

3031 for sequence_id in sequence_ids: 

3032 seq_to_sample[sequence_id] = sample 

3033 

3034 # Read cluster alignment from the file 

3035 alignment_path = os.path.join(alignments_dir, f"{cluster}.aln") 

3036 if not os.path.exists(alignment_path): 

3037 # Alignment might not exist if user excluded singletons in align 

3038 num_samples = len(list(set(seq_to_sample.values()))) 

3039 if num_samples == 1: 

3040 logging.debug(f"Singleton {cluster} alignment was not found: {alignment_path}") 

3041 continue 

3042 # otherwise, this alignment should exist, stop here 

3043 else: 

3044 msg = f"{cluster} alignment was not found: {alignment_path}" 

3045 logging.error(msg) 

3046 raise Exception(msg) 

3047 

3048 alignment = OrderedDict() 

3049 

3050 for record in SeqIO.parse(alignment_path, "fasta"): 

3051 sample = seq_to_sample[record.id] 

3052 if sample not in alignment: 

3053 alignment[sample] = OrderedDict() 

3054 alignment[sample][record.id] = record.seq 

3055 

3056 # Parse out the fragment pieces based on the pattern of '-' 

3057 cluster_variants = OrderedDict() 

3058 

3059 present_samples = set() 

3060 for sample in alignment: 

3061 sample_frag_ranges = [] 

3062 for seq in alignment[sample].values(): 

3063 non_ambig_nuc = set([nuc for nuc in seq if nuc in NUCLEOTIDES]) 

3064 if len(non_ambig_nuc) > 0: 

3065 present_samples.add(sample) 

3066 frag_ranges = [(start+1, stop+1) for start,stop in list(get_ranges([i for i,nuc in enumerate(seq) if nuc != "-"]))] 

3067 frag_ranges_collapse = collapse_ranges(ranges=frag_ranges, min_len=min_len, min_gap=min_indel_len) 

3068 sample_frag_ranges.append(frag_ranges_collapse) 

3069 

3070 # Sort them for output consistency 

3071 sample_frag_ranges = sorted(sample_frag_ranges) 

3072 

3073 for frag in sample_frag_ranges: 

3074 # Simple presence  

3075 frag_text = f"{cluster}|structural:" + "_".join([f"{start}-{stop}" for start,stop in frag]) 
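# e.g. (illustrative) a sequence covering positions 1-195 and 202-240 of the
# cluster alignment becomes: Cluster_1|structural:1-195_202-240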

3076 if frag_text not in cluster_variants: 

3077 cluster_variants[frag_text] = [] 

3078 if sample not in cluster_variants[frag_text]: 

3079 cluster_variants[frag_text].append(sample) 

3080 # Copy number (ex. 2X) 

3081 count = sample_frag_ranges.count(frag) 

3082 if count > 1: 

3083 copy_number_text = f"{frag_text}|{count}X" 

3084 if copy_number_text not in cluster_variants: 

3085 cluster_variants[copy_number_text] = [] 

3086 if sample not in cluster_variants[copy_number_text]: 

3087 cluster_variants[copy_number_text].append(sample) 

3088 

3089 # If there is only one structure, it's not a variant 

3090 if len(cluster_variants) <= 1: continue 

3091 # Samples that are truly missing will be given a "." 

3092 missing_samples = [s for s in all_samples if s not in present_samples] 

3093 for variant,samples in cluster_variants.items(): 

3094 if len(samples) == len(all_samples): continue 

3095 variants[variant] = {"present": samples, "missing": missing_samples} 

3096 

3097 # ------------------------------------------------------------------------- 

3098 # Write Structural Variants Rtab 

3099 

3100 rtab_path = os.path.join(output_dir, f"{prefix}structural.Rtab") 

3101 logging.info(f"Writing variants: {rtab_path}") 
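# Illustrative row (hypothetical, three samples): "1" = variant present,
# "0" = sample has the cluster but a different structure, "." = cluster missing:
#   Cluster_1|structural:1-195_202-240    1    0    .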

3102 

3103 with open(rtab_path, 'w') as outfile: 

3104 header = ["Variant"] + all_samples 

3105 outfile.write("\t".join(header) + "\n") 

3106 for variant,data in tqdm(variants.items()): 

3107 observations = ["1" if s in data["present"] else "." if s in data["missing"] else "0" for s in all_samples] 

3108 line = "\t".join([variant] + observations) 

3109 outfile.write(line + "\n") 

3110 

3111 return rtab_path 

3112 

3113 

3114def snps( 

3115 alignment: str, 

3116 bed: str, 

3117 consensus: str, 

3118 outdir: str = ".", 

3119 prefix: str = None, 

3120 structural: str = None, 

3121 core: float = 0.95, 

3122 indel_window: int = 0, 

3123 snp_window: int = 0, 

3124 args: str = None, 

3125 ): 

3126 """ 

3127 Extract SNPs from a pangenome alignment. 

3128 

3129 Takes as input the pangenome alignment fasta, bed, and consensus file. 

3130 Outputs an Rtab file of SNPs. 

3131 

3132 >>> snps(alignment="pangenome.aln", bed="pangenome.bed", consensus="pangenome.consensus.fasta") 

3133 >>> snps(alignment="pangenome.aln", bed="pangenome.bed", consensus="pangenome.consensus.fasta", structural="structural.Rtab", indel_window=10, snp_window=3) 

3134 

3135 :param alignment: FASTA file of the pangenome alignment from align. 

3136 :param bed: BED file of the pangenome coordinates from align. 

3137 :param consensus: FASTA file of the pangenome consensus from align. 

3138 :param outdir: Output directory. 

3139 :param prefix: Prefix for output files. 

3140 :param structural: Rtab file of structural variants from structural. 

3141 :param core: Core genome threshold for calling core SNPs. 

3142 :param indel_window: Exclude SNPs that are within this proximity to indels. 

3143 :param snp_window: Exclude SNPs that are within this proximity to another SNP. 

3144 :param args: Str of additional arguments [not implemented] 

3145 """ 

3146 

3147 alignment_path, bed_path, consensus_path, structural_path, output_dir = alignment, bed, consensus, structural, outdir 

3148 args = args if args != None else "" 

3149 prefix = f"{prefix}." if prefix != None else "" 

3150 

3151 # Check output directory 

3152 check_output_dir(output_dir) 

3153 

3154 # ------------------------------------------------------------------------- 

3155 # Read Pangenome Bed 

3156 

3157 clusters = {} 

3158 logging.info(f"Reading bed: {bed_path}") 

3159 with open(bed_path) as infile: 

3160 lines = infile.readlines() 

3161 for line in tqdm(lines): 

3162 line = [l.strip() for l in line.split("\t")] 

3163 start, end, name = (int(line[1]), int(line[2]), line[3]) 

3164 info = {n.split("=")[0]:n.split("=")[1] for n in name.split(";")} 

3165 cluster = info["cluster"] 

3166 synteny = info["synteny"] 

3167 clusters[cluster] = {"start": start, "end": end, "synteny": synteny } 

3168 

3169 # ------------------------------------------------------------------------- 

3170 # Read Pangenome Alignment 

3171 

3172 alignment = {} 

3173 all_samples = [] 

3174 logging.info(f"Reading alignment: {alignment_path}") 

3175 with open(alignment_path) as infile: 

3176 records = infile.read().split(">")[1:] 

3177 for record in tqdm(records): 

3178 record_split = record.split("\n") 

3179 sample = record_split[0] 

3180 if sample not in all_samples: 

3181 all_samples.append(sample) 

3182 sequence = "".join(record_split[1:]).replace("\n", "").upper() 

3183 alignment[sample] = sequence 

3184 

3185 # ------------------------------------------------------------------------- 

3186 # Read Consensus Sequence 

3187 

3188 representative = "" 

3189 logging.info(f"Reading consensus sequence: {consensus_path}") 

3190 with open(consensus_path) as infile: 

3191 record = infile.read().split(">")[1:][0] 

3192 record_split = record.split("\n") 

3193 representative = "".join(record_split[1:]).replace("\n", "").upper() 

3194 

3195 alignment_len = len(alignment[all_samples[0]]) 

3196 

3197 # ------------------------------------------------------------------------- 

3198 # Read Optional Structural Rtab 

3199 

3200 # Use the structural variants to locate the terminal ends of sequences 

3201 

3202 structural = OrderedDict() 

3203 if structural_path: 

3204 logging.info(f"Reading structural: {structural_path}") 

3205 with open(structural_path) as infile: 

3206 header = infile.readline().strip().split("\t") 

3207 samples = header[1:] 

3208 lines = infile.readlines() 

3209 for line in tqdm(lines): 

3210 row = [v.strip() for v in line.split("\t")] 

3211 variant = row[0].split("|") 

3212 cluster = variant[0] 

3213 if cluster not in structural: 

3214 structural[cluster] = OrderedDict() 

3215 # Cluster_1|structural:1-195_202-240 --> Terminal=(1,240) 

3216 coords = variant[1].split(":")[1].split("_") 

3217 start = int(coords[0].split("-")[0]) 

3218 end = int(coords[len(coords) - 1].split("-")[1]) 

3219 for sample,observation in zip(samples, row[1:]): 

3220 if observation == "1": 

3221 if sample not in structural[cluster]: 

3222 structural[cluster][sample] = [] 

3223 structural[cluster][sample].append((start, end)) 

3224 

3225 # ------------------------------------------------------------------------- 

3226 # Extract SNPs from Alignment 

3227 # ------------------------------------------------------------------------- 

3228 

3229 logging.info("Extracting SNPs.") 

3230 

3231 constant_sites = {n:0 for n in NUCLEOTIDES} 

3232 

3233 snps_data = OrderedDict() 

3234 

3235 # Iterate over alignment according to cluster 

3236 for cluster,cluster_data in tqdm(clusters.items()): 

3237 synteny = cluster_data["synteny"] 

3238 cluster_start, cluster_stop = cluster_data["start"], cluster_data["end"] 

3239 for i in range(cluster_start - 1, cluster_stop): 

3240 # Extract nucleotides for all samples 

3241 nuc = {s:alignment[s][i] for s in all_samples} 

3242 nuc_non_ambig = [n for n in nuc.values() if n in NUCLEOTIDES] 

3243 nuc_non_ambig_set = set(nuc_non_ambig) 

3244 prop_non_ambig = len(nuc_non_ambig) / len(nuc) 

3245 

3246 # Option #1. If all missing/ambiguous, skip over  

3247 if len(nuc_non_ambig_set) == 0: 

3248 continue 

3249 # Option #2. All the same/invariant 

3250 elif len(nuc_non_ambig_set) == 1: 

3251 # record for constant sites if it's a core site 

3252 if prop_non_ambig >= core: 

3253 n = list(nuc_non_ambig)[0] 

3254 constant_sites[n] += 1 

3255 continue 

3256 # Option #3. Variant, process it further 

3257 # Make snp positions 1-based (like VCF) 

3258 pangenome_pos = i + 1 

3259 cluster_pos = pangenome_pos - cluster_start 

3260 # Treat the representative nucleotide as the reference 

3261 ref = representative[i] 

3262 alt = [] 

3263 genotypes = [] 

3264 sample_genotypes = {} 

3265 

3266 # Filter on multi-allelic and indel proximity 

3267 for s,n in nuc.items(): 

3268 

3269 # ------------------------------------------------------------- 

3270 # Indel proximity checking 

3271 

3272 if indel_window > 0: 

3273 # Check window around SNP for indels 

3274 upstream_i = i - indel_window 

3275 if upstream_i < cluster_start: 

3276 upstream_i = cluster_start 

3277 downstream_i = i + indel_window 

3278 if downstream_i > (cluster_stop - 1): 

3279 downstream_i = (cluster_stop - 1) 

3280 

3281 # Check if we should truncate the downstream/upstream i by terminals 

3282 if cluster in structural and s in structural[cluster]: 

3283 for start,stop in structural[cluster][s]: 

3284 # Adjust start,stop coordinates from cluster to whole genome 

3285 start_i, stop_i = (cluster_start + start) - 1, (cluster_start + stop) - 1 

3286 # Original: 226-232, Terminal: 0,230 --> 226-230 

3287 if stop_i > upstream_i and stop_i < downstream_i: 

3288 downstream_i = stop_i 

3289 # Original: 302-308, Terminal: 303,389 --> 303-308 

3290 if start_i > upstream_i and start_i < downstream_i: 

3291 upstream_i = start_i 

3292 

3293 context = alignment[s][upstream_i:downstream_i + 1] 

3294 

3295 # If indel is found nearby, mark this as missing/ambiguous 

3296 if "-" in context: 

3297 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) in {s} was filtered out due to indel proximity: {context}") 

3298 genotypes.append(".") 

3299 nuc[s] = "." 

3300 sample_genotypes[s] = "./." 

3301 continue 

3302 

3303 # -------------------------------------------------------------  

3304 # sample nuc is different than ref 

3305 if n != ref: 

3306 # handle ambiguous/missing 

3307 if n not in NUCLEOTIDES: 

3308 genotypes.append(".") 

3309 nuc[s] = "." 

3310 sample_genotypes[s] = "./." 

3311 continue 

3312 

3313 # add this as a new alt genotype 

3314 if n not in alt: 

3315 alt.append(n) 

3316 sample_genotypes[s] = "1/1" 

3317 else: 

3318 sample_genotypes[s] = "0/0" 

3319 

3320 genotypes.append(([ref] + alt).index(n)) 

3321 

3322 # Update our non-ambiguous nucleotides 

3323 nuc_non_ambig = [n for n in nuc.values() if n in NUCLEOTIDES] 

3324 nuc_non_ambig_set = set(nuc_non_ambig) 

3325 

3326 genotypes_non_ambig = [g for g in genotypes if g != "."] 

3327 

3328 # Check if it's still a variant position after indel filtering 

3329 if len(alt) == 0 or len(genotypes_non_ambig) == 1: 

3330 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as mono-allelic: {ref}") 

3331 constant_sites[ref] += 1 

3332 continue 

3333 # If there is more than 1 ALT allele, this is non-biallelic (multi-allelic), so we'll filter it out 

3334 elif len(alt) > 1: 

3335 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as multi-allelic: {','.join([ref] + alt)}") 

3336 continue 

3337 

3338 alt = list(alt)[0] 

3339 # Use "." to indicate missing values (ex. Rtab format) 

3340 observations = ["1" if n == alt else "." if n != ref else "0" for n in nuc.values()] 

3341 # One more check if it's no longer variant... 

3342 observations_non_ambig = set([o for o in observations if o == "1" or o == "0"]) 

3343 if len(observations_non_ambig) <= 1: 

3344 observations_nuc = alt[0] if "1" in observations_non_ambig else ref if "0" in observations_non_ambig else "N" 

3345 logging.debug(f"{cluster} {ref}{cluster_pos}{n} ({ref}{pangenome_pos}{n}) was filtered out as mono-allelic: {observations_nuc}") 

3346 continue 

3347 

3348 # This is a valid SNP! Calculate extra stats 

3349 allele_frequency = len([n for n in nuc_non_ambig if n != ref]) / len(nuc) 

3350 

3351 # Update data 

3352 pangenome_snp = f"{ref}{pangenome_pos}{alt}" 

3353 cluster_snp = f"{cluster}|snp:{ref}{cluster_pos}{alt}" 

3354 snps_data[cluster_snp] = OrderedDict({ 

3355 "pangenome_snp": pangenome_snp, 

3356 "prop_non_ambig" : prop_non_ambig, 

3357 "allele_frequency": allele_frequency, 

3358 "cluster" : cluster, 

3359 "synteny": synteny, 

3360 "observations": observations, 

3361 "pangenome_pos": pangenome_pos, 

3362 "cluster_pos": cluster_pos, 

3363 "nuc": nuc, 

3364 "ref": ref, 

3365 "alt": alt, 

3366 "sample_genotypes": sample_genotypes, 

3367 }) 

3368 

3369 if snp_window > 0: 

3370 snps_exclude = set() 

3371 logging.info(f"Filtering out SNPs within {snp_window} bp of each other.") 

3372 snps_order = list(snps_data.keys()) 

3373 for i,snp in enumerate(snps_order): 

3374 cluster, coord = snps_data[snp]["cluster"], snps_data[snp]["cluster_pos"] 

3375 if i > 0: 

3376 prev_snp = snps_order[i-1] 

3377 prev_cluster, prev_coord = snps_data[prev_snp]["cluster"], snps_data[prev_snp]["cluster_pos"], 

3378 if (cluster == prev_cluster) and (coord - prev_coord <= snp_window): 

3379 snps_exclude.add(snp) 

3380 snps_exclude.add(prev_snp) 

3381 logging.debug(f"SNP {snp} and {prev_snp} are filtered out due to proximity <= {snp_window}.") 

3382 snps_data = OrderedDict({snp:data for snp,data in snps_data.items() if snp not in snps_exclude}) 

3383 

3384 # ------------------------------------------------------------------------- 

3385 # Prepare Outputs 

3386 

3387 snp_all_alignment = { s:[] for s in all_samples} 

3388 snp_core_alignment = { s:[] for s in all_samples} 

3389 snp_all_table = open(os.path.join(output_dir, f"{prefix}snps.all.tsv"), 'w') 

3390 snp_core_table = open(os.path.join(output_dir, f"{prefix}snps.core.tsv"), 'w') 

3391 snp_all_vcf = open(os.path.join(output_dir, f"{prefix}snps.all.vcf"), 'w') 

3392 snp_core_vcf = open(os.path.join(output_dir, f"{prefix}snps.core.vcf"), 'w') 

3393 snp_rtab_path = os.path.join(output_dir, f"{prefix}snps.Rtab") 

3394 snp_rtab = open(snp_rtab_path, 'w') 

3395 

3396 # TSV Table Header 

3397 header = ["snp", "pangenome_snp", "prop_non_ambig", "allele_frequency", "cluster"] + all_samples 

3398 snp_all_table.write("\t".join(header) + "\n") 

3399 snp_core_table.write("\t".join(header) + "\n") 

3400 

3401 # Rtab Header 

3402 header = ["Variant"] + all_samples 

3403 snp_rtab.write("\t".join(header) + "\n") 

3404 

3405 # VCF Header 

3406 all_samples_header = "\t".join(all_samples) 

3407 header = textwrap.dedent( 

3408 f"""\ 

3409 ##fileformat=VCFv4.2 

3410 ##contig=<ID=pangenome,length={alignment_len}> 

3411 ##INFO=<ID=CR,Number=0,Type=Flag,Description="Consensus reference allele, not based on real reference genome"> 

3412 ##INFO=<ID=C,Number=1,Type=String,Description="Cluster"> 

3413 ##INFO=<ID=CP,Number=1,Type=Int,Description="Cluster Position"> 

3414 ##INFO=<ID=S,Number=1,Type=String,Description="Synteny"> 

3415 ##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype"> 

3416 ##FILTER=<ID=IW,Description="Indel window proximity filter {indel_window}"> 

3417 ##FILTER=<ID=SW,Description="SNP window proximity filter {snp_window}"> 

3418 #CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t{all_samples_header} 

3419 """ 

3420 ) 

3421 snp_all_vcf.write(header) 

3422 snp_core_vcf.write(header) 

3423 

3424 logging.info(f"Finalizing all and core SNPs ({core}).") 

3425 for snp,d in snps_data.items(): 

3426 # Table 

3427 table_row = [snp] + [d["pangenome_snp"], d["prop_non_ambig"], d["allele_frequency"], d["cluster"]] + d["observations"] 

3428 snp_all_table.write("\t".join(str(v) for v in table_row) + "\n") 

3429 # Alignment 

3430 for s in all_samples: 

3431 snp_all_alignment[s].append(d["nuc"][s]) 

3432 # Rtab 

3433 rtab_row = [snp] + d["observations"] 

3434 snp_rtab.write("\t".join(rtab_row) + "\n") 

3435 # VCF 

3436 pangenome_pos, cluster_pos = d["pangenome_pos"], d["cluster_pos"] 

3437 cluster, synteny = d["cluster"], d["synteny"] 

3438 ref, alt = d["ref"], d["alt"] 

3439 info = f"CR;C={cluster};CP={cluster_pos},S={synteny}" 

3440 genotypes = [] 

3441 for sample in all_samples: 

3442 genotypes.append(d["sample_genotypes"][sample]) 

3443 vcf_row = ["pangenome", pangenome_pos, snp, ref, alt, ".", "PASS", info, "GT"] + genotypes 

3444 vcf_line = "\t".join([str(v) for v in vcf_row]) 

3445 snp_all_vcf.write(vcf_line + "\n") 

3446 

3447 # core SNPs 

3448 if d["prop_non_ambig"] >= core: 

3449 # Table 

3450 snp_core_table.write("\t".join(str(v) for v in table_row) + "\n") 

3451 for s in all_samples: 

3452 snp_core_alignment[s].append(d["nuc"][s]) 

3453 # VCF 

3454 snp_core_vcf.write(vcf_line + "\n") 

3455 

3456 # ------------------------------------------------------------------------- 

3457 # Write SNP fasta alignments 

3458 

3459 snp_all_alignment_path = os.path.join(output_dir, f"{prefix}snps.all.fasta") 

3460 logging.info(f"Writing all SNP alignment: {snp_all_alignment_path}") 

3461 with open(snp_all_alignment_path, 'w') as outfile: 

3462 for sample in tqdm(all_samples): 

3463 sequence = "".join(snp_all_alignment[sample]) 

3464 outfile.write(f">{sample}\n{sequence}\n") 

3465 

3466 snp_core_alignment_path = os.path.join(output_dir, f"{prefix}snps.core.fasta") 

3467 logging.info(f"Writing {core} core SNP alignment: {snp_core_alignment_path}") 

3468 with open(snp_core_alignment_path, 'w') as outfile: 

3469 for sample in tqdm(all_samples): 

3470 sequence = "".join(snp_core_alignment[sample]) 

3471 outfile.write(f">{sample}\n{sequence}\n") 

3472 

3473 # ------------------------------------------------------------------------- 

3474 # Write constant sites 

3475 

3476 constant_sites_path = os.path.join(output_dir, f"{prefix}snps.constant_sites.txt") 

3477 logging.info(f"Writing constant sites: {constant_sites_path}") 
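# The counts are written comma-separated in NUCLEOTIDES order (A,C,G,T), which
# is the order tree() later passes to IQ-TREE via -fconst, e.g. (illustrative):
#   132044,98122,97810,131559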

3478 with open(constant_sites_path, 'w') as outfile: 

3479 line = ",".join([str(v) for v in list(constant_sites.values())]) 

3480 outfile.write(line + "\n") 

3481 

3482 # ------------------------------------------------------------------------- 

3483 # Cleanup 

3484 

3485 snp_all_table.close() 

3486 snp_core_table.close() 

3487 snp_rtab.close() 

3488 snp_all_vcf.close() 

3489 snp_core_vcf.close() 

3490 

3491 return snp_rtab_path 

3492 

3493 

3494def presence_absence( 

3495 clusters: str, 

3496 outdir: str = ".", 

3497 prefix: str = None, 

3498 args: str = None 

3499 ): 

3500 """ 

3501 Extract presence absence of summarized clusters. 

3502 

3503 Takes as input the TSV file of summarized clusters from summarize. 

3504 Outputs an Rtab file of cluster presence/absence. 

3505 

3506 Examples: 

3507 >>> presence_absence(clusters="summarize.clusters.tsv") 

3508 

3509 :param clusters: Path to TSV of summarized clusters from summarize. 

3510 :param outdir: Output directory. 

3511 :param prefix: Prefix for output files. 

3512 :param args: Str of additional arguments [not implemented]. 

3513 """ 

3514 

3515 clusters_path, output_dir = clusters, outdir 

3516 args = args if args != None else "" 

3517 prefix = f"{prefix}." if prefix != None else "" 

3518 

3519 # Check output directory 

3520 check_output_dir(output_dir) 

3521 

3522 all_samples = [] 

3523 

3524 # ------------------------------------------------------------------------- 

3525 # Read Clusters 

3526 

3527 logging.info(f"Reading summarized clusters: {clusters_path}") 

3528 variants = OrderedDict() 

3529 

3530 with open(clusters_path) as infile: 

3531 header = [line.strip() for line in infile.readline().split("\t")] 

3532 # sample columns begin after dbxref_alt 

3533 all_samples = header[header.index("dbxref_alt")+1:] 

3534 lines = infile.readlines() 

3535 for line in tqdm(lines): 

3536 row = [l.strip() for l in line.split("\t")] 

3537 data = {k:v for k,v in zip(header, row)} 

3538 cluster = data["cluster"] 

3539 variants[cluster] = [] 

3540 

3541 # Get samples with sequences for this cluster 

3542 for sample in all_samples: 

3543 sequence_ids = data[sample].split(",") if data[sample] != "" else [] 

3544 if len(sequence_ids) > 0: 

3545 variants[cluster].append(sample) 

3546 # If the cluster is present in all samples, it's not "variable" 

3547 if len(variants[cluster]) == len(all_samples): 

3548 del variants[cluster] 

3549 

3550 # ------------------------------------------------------------------------- 

3551 # Write Presence Absence Rtab 

3552 

3553 rtab_path = os.path.join(output_dir, f"{prefix}presence_absence.Rtab") 

3554 logging.info(f"Writing variants: {rtab_path}") 

3555 

3556 with open(rtab_path, 'w') as outfile: 

3557 header = ["Variant"] + all_samples 

3558 outfile.write("\t".join(header) + "\n") 

3559 for variant,samples in tqdm(variants.items()): 

3560 variant = f"{variant}|presence_absence" 

3561 observations = ["1" if s in samples else "0" for s in all_samples] 

3562 line = "\t".join([variant] + observations) 

3563 outfile.write(line + "\n") 

3564 

3565 return rtab_path 

3566 

3567 

3568def combine( 

3569 rtab: list, 

3570 outdir: str = ".", 

3571 prefix: str = None, 

3572 args: str = None, 

3573 ): 

3574 """ 

3575 Combine variants from multiple Rtab files. 

3576 

3577 Takes as input a list of file paths to Rtab files. Outputs an Rtab file with 

3578 the variants concatenated, ensuring consistent ordering of the sample columns. 

3579 

3580 >>> combine(rtab=["snps.Rtab", "structural.Rtab", "presence_absence.Rtab"]) 

3581 

3582 :param rtab: List of paths to input Rtab files to combine. 

3583 :param outdir: Output directory. 

3584 :param prefix: Prefix for output files. 

3585 :param args: Str of additional arguments [not implemented]. 

3586 """ 

3587 

3588 output_dir = outdir 

3589 prefix = f"{prefix}." if prefix != None else "" 

3590 # Check output directory 

3591 check_output_dir(output_dir) 

3592 

3593 all_samples = [] 

3594 rtab_path = os.path.join(output_dir, f"{prefix}combine.Rtab") 

3595 logging.info(f"Writing combined Rtab: {rtab_path}") 

3596 with open(rtab_path, 'w') as outfile: 

3597 for file_path in rtab: 

3598 if not file_path: continue 

3599 logging.info(f"Reading variants: {file_path}") 

3600 with open(file_path) as infile: 

3601 header = infile.readline().strip().split("\t") 

3602 samples = header[1:] 

3603 if len(all_samples) == 0: 

3604 all_samples = samples 

3605 outfile.write("\t".join(["Variant"] + all_samples) + "\n") 

3606 # Organize the sample order consistently 

3607 lines = infile.readlines() 

3608 for line in tqdm(lines): 

3609 row = [v.strip() for v in line.split("\t")] 

3610 variant = row[0] 

3611 observations = {k:v for k,v in zip(samples, row[1:])} 

3612 out_row = [variant] + [observations[s] for s in all_samples] 

3613 outfile.write("\t".join(out_row) + "\n") 

3614 

3615 return rtab_path 

3616 

3617 

3618def root_tree( 

3619 tree: str, 

3620 outgroup: list = [], 

3621 tree_format: str = "newick", 

3622 outdir: str = ".", 

3623 prefix: str = None, 

3624 args: str = None, 

3625 ): 

3626 """ 

3627 Root a tree on outgroup taxa. 

3628 """ 

3629 

3630 tree_path, output_dir = tree, outdir 

3631 prefix = f"{prefix}." if prefix != None else "" 

3632 # Check output directory 

3633 check_output_dir(output_dir) 

3634 

3635 if type(outgroup) == str: 

3636 outgroup = outgroup.split(",") 

3637 elif type(outgroup) != list: 

3638 msg = f"The outgroup must be either a list or CSV string: {outgroup}" 

3639 logging.error(msg) 

3640 raise Exception(msg) 

3641 

3642 # Prioritize tree path 

3643 logging.info(f"Reading tree: {tree_path}") 

3644 tree = read_tree(tree_path, tree_format=tree_format) 

3645 tree.is_rooted = True 

3646 # Cleanup node/taxon labels 

3647 for node in list(tree.preorder_node_iter()): 

3648 if node.taxon: 

3649 node.label = str(node.taxon) 

3650 node.label = node.label.replace("\"", "").replace("'", "") if node.label != None else None 

3651 

3652 logging.info(f"Rooting tree with outgroup: {outgroup}") 

3653 

3654 tree_tips = [n.label for n in tree.leaf_node_iter()] 

3655 if sorted(tree_tips) == sorted(outgroup): 

3656 msg = "Your outgroup contains all taxa in the tree." 

3657 logging.error(msg) 

3658 raise Exception(msg) 

3659 

3660 # sort the outgroup labels for easier checks later 

3661 outgroup = sorted(outgroup) 

3662 

3663 # Case 1: No outgroup requested, use first tip 

3664 if len(outgroup) == 0: 

3665 outgroup_node = [n for n in tree.leaf_node_iter()][0] 

3666 logging.info(f"No outgroup requested, rooting on first tip: {outgroup_node.label}") 

3667 edge_length = outgroup_node.edge_length / 2 

3668 # Case 2: 1 or more outgroups requested 

3669 else: 

3670 # Case 2a: find a clade that is only the outgroup 

3671 logging.info(f"Searching for outgroup clade: {outgroup}") 

3672 outgroup_node = None 

3673 for node in tree.nodes(): 

3674 node_tips = sorted([n.label for n in node.leaf_iter()]) 

3675 if node_tips == outgroup: 

3676 logging.info(f"Outgroup clade found.") 

3677 outgroup_node = node 

3678 break 

3679 # Case 2b: find a clade that is only the ingroup 

3680 if outgroup_node == None: 

3681 logging.info(f"Searching for ingroup clade that does not include: {outgroup}") 

3682 for node in tree.nodes(): 

3683 node_tips = sorted([n.label for n in node.leaf_iter() if n.label in outgroup]) 

3684 if len(node_tips) == 0: 

3685 # If this is a taxon, make sure the outgroup is complete 

3686 if node.taxon != None: 

3687 remaining_taxa = sorted([t for t in tree_tips if t != node.label]) 

3688 if remaining_taxa != outgroup: 

3689 continue 

3690 logging.info(f"Ingroup clade found.") 

3691 outgroup_node = node 

3692 break 

3693 # If we couldn't find an outgroup, it's probably not monophyletic 

3694 # If we created the tree with the same outgroup in IQTREE, this  

3695 # shouldn't be possible, but we'll handle it just in case 

3696 if outgroup_node == None: 

3697 msg = "Failed to find the outgroup clade. Are you sure your outgroup is monophyletic?" 

3698 logging.error(msg) 

3699 raise Exception(msg) 

3700 edge_length = outgroup_node.edge_length / 2 

3701 

3702 tree.reroot_at_edge(outgroup_node.edge, update_bipartitions=True, length1=edge_length, length2=edge_length ) 

3703 

3704 rooted_path = os.path.join(output_dir, f"{prefix}rooted.{tree_format}") 

3705 logging.info(f"Writing rooted tree: {rooted_path}") 

3706 # Suppress the [&R] prefix, which causes problems in downstream programs  

3707 tree.write(path=rooted_path, schema=tree_format, suppress_rooting=True) 

3708 

3709 return rooted_path 

3710 

3711def root_tree_midpoint(tree): 

3712 """ 

3713 Reroots the tree at the midpoint of the longest distance between 

3714 two taxa in a tree. 

3715 

3716 This is a modification of dendropy's reroot_at_midpoint function: 

3717 https://github.com/jeetsukumaran/DendroPy/blob/49cf2cc2/src/dendropy/datamodel/treemodel/_tree.py#L2616 

3718 

3719 This is a utility function that is used by gwas for the kinship 

3720 similarity matrix. 

3721 

3722 :param tree: dendropy.datamodel.treemodel._tree.Tree 

3723 :return: dendropy.datamodel.treemodel._tree.Tree 

3724 """ 

3725 from dendropy.calculate.phylogeneticdistance import PhylogeneticDistanceMatrix 

3726 pdm = PhylogeneticDistanceMatrix.from_tree(tree) 

3727 

3728 tax1, tax2 = pdm.max_pairwise_distance_taxa() 

3729 plen = float(pdm.patristic_distance(tax1, tax2)) / 2 

3730 

3731 n1 = tree.find_node_with_taxon_label(tax1.label) 

3732 n2 = tree.find_node_with_taxon_label(tax2.label) 

3733 

3734 # Find taxa furthest from the root 

3735 mrca = pdm.mrca(tax1, tax2) 

3736 n1_dist, n2_dist = n1.distance_from_root(), n2.distance_from_root() 

3737 

3738 # Start at the furthest taxon, walk up towards the root, and look for the 

3739 # edge containing the midpoint 
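# Illustrative sketch (hypothetical branch lengths): if the maximum pairwise
# distance is 1.0, then plen = 0.5. Walking up from the deeper tip, a first edge
# of 0.3 leaves remaining_len = 0.2; a second edge of 0.4 overshoots
# (remaining_len = -0.2), so that edge is split into l1 = 0.2 and l2 = 0.2,
# placing the new root exactly 0.5 from the deeper tip (0.3 + 0.2).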

3740 node = n1 if n1_dist >= n2_dist else n2 

3741 remaining_len = plen 

3742 

3743 while node != mrca: 

3744 edge_length = node.edge.length 

3745 remaining_len -= edge_length 

3746 # Found midpoint node 

3747 if remaining_len == 0: 

3748 node = node._parent_node 

3749 logging.info(f"Rerooting on midpoint node: {node}") 

3750 tree.reroot_at_node(node, update_bipartitions=True) 

3751 break 

3752 # Found midpoint edge 

3753 elif remaining_len < 0: 

3754 l2 = edge_length + remaining_len 

3755 l1 = edge_length - l2 

3756 logging.info(f"Rerooting on midpoint edge, l1={l1}, l2={l2}") 

3757 tree.reroot_at_edge(node.edge, update_bipartitions=True, length1=l1, length2=l2 ) 

3758 break 

3759 node = node._parent_node 

3760 

3761 return tree 

3762 

3763 

3764def read_tree(tree, tree_format:str="newick"): 

3765 """Read a tree from file into an object, and clean up labels.""" 

3766 from dendropy import Tree 

3767 

3768 tree = Tree.get(path=tree, schema=tree_format, tree_offset=0, preserve_underscores=True) 

3769 tree.is_rooted = True 

3770 

3771 # Give tip nodes the name of their taxon, cleanup tip labels 

3772 for node in list(tree.preorder_node_iter()): 

3773 if node.taxon: 

3774 node.label = str(node.taxon) 

3775 node.label = node.label.replace("\"", "").replace("'", "") if node.label != None else None 

3776 

3777 return tree 

3778 

3779 

3780def tree( 

3781 alignment: str, 

3782 outdir: str = ".", 

3783 prefix: str = None, 

3784 threads: int = 1, 

3785 constant_sites: str = None, 

3786 args: str = TREE_ARGS, 

3787 ): 

3788 """ 

3789 Estimate a maximum-likelihood tree with IQ-TREE. 

3790 

3791 Takes as input a multiple sequence alignment in FASTA format. If a SNP 

3792 alignment is provided, an optional text file of constant sites can be  

3793 included for correction. Outputs a maximum-likelihood tree, as well as  

3794 additional rooted trees if an outgroup is specified in the iqtree args. 

3795 

3796 >>> tree(alignment="snps.core.fasta", constant_sites="snps.constant_sites.txt") 

3797 >>> tree(alignment="pangenome.aln", threads=4, args='--ufboot 1000 -o sample1') 

3798 

3799 :param alignment: Path to multiple sequence alignment for IQ-TREE. 

3800 :param outdir: Output directory. 

3801 :param prefix: Prefix for output files. 

3802 :param threads: CPU threads for IQ-TREE. 

3803 :param constant_sites: Path to text file of constant site corrections. 

3804 :param args: Str of additional arguments for IQ-TREE. 

3805 """ 

3806 

3807 alignment_path, constant_sites_path, output_dir = alignment, constant_sites, outdir 

3808 args = args if args != None else "" 

3809 prefix = f"{prefix}." if prefix != None else "" 

3810 

3811 # Check output directory 

3812 check_output_dir(output_dir) 

3813 

3814 # Check for seed and threads conflict 

3815 if int(threads) > 1 and "--seed" in args: 

3816 logging.warning( 

3817 f""" 

3818 Using multiple threads is not guaranteed to give reproducible 

3819 results, even when using the --seed argument: 

3820 See this issue for more information: https://github.com/iqtree/iqtree2/discussions/233 

3821 """) 

3822 args += f" -T {threads}" 

3823 

3824 # Check for outgroup rooting 

3825 if "-o " in args: 

3826 outgroup = extract_cli_param(args, "-o") 

3827 logging.info(f"Tree will be rooted on outgroup: {outgroup}") 

3828 outgroup_labels = outgroup.split(",") 

3829 else: 

3830 outgroup = None 

3831 outgroup_labels = [] 

3832 

3833 # Check for constant sites input for SNP 

3834 if constant_sites_path: 

3835 with open(constant_sites_path) as infile: 

3836 constant_sites = infile.read().strip() 

3837 constant_sites_arg = f"-fconst {constant_sites}" 

3838 else: 

3839 constant_sites_arg = "" 

3840 

3841 # ------------------------------------------------------------------------- 

3842 # Estimate maximum likelihood tree 

3843 iqtree_prefix = os.path.join(output_dir, f"{prefix}tree") 

3844 run_cmd(f"iqtree -s {alignment_path} --prefix {iqtree_prefix} {constant_sites_arg} {args}") 

3845 tree_path = f"{iqtree_prefix}.treefile" 

3846 

3847 # Check for outgroup warnings in log 

3848 log_path = f"{iqtree_prefix}.log" 

3849 with open(log_path) as infile: 

3850 for line in infile: 

3851 outgroup_msg = "Branch separating outgroup is not found" 

3852 if outgroup_msg in line: 

3853 msg = f"Branch separating outgroup is not found. Please check if your ({outgroup}) is monophyletic in: {tree_path}" 

3854 logging.error(msg) 

3855 raise Exception(msg) 

3856 

3857 # ------------------------------------------------------------------------- 

3858 # Fix root position  

3859 

3860 tree_path = root_tree(tree=tree_path, outgroup=outgroup_labels, outdir=outdir, prefix=f"{prefix}tree") 

3861 tree = read_tree(tree_path, tree_format="newick") 

3862 tips = sorted([n.label for n in tree.leaf_node_iter()]) 

3863 # Sort child nodes by taxon label (descending) so that the tip order 

3864 # in the written tree is deterministic between runs 

3865 for node in tree.preorder_node_iter(): 

3866 node._child_nodes.sort( 

3867 key=lambda node: getattr(getattr(node, "taxon", None), "label", ""), 

3868 reverse=True 

3869 ) 

3870 

3871 tree_path = f"{iqtree_prefix}.rooted.nwk" 

3872 logging.info(f"Writing rooted tree: {tree_path}") 

3873 # Suppress the [&R] prefix, which causes problems in downstream programs  

3874 tree.write(path=tree_path, schema="newick", suppress_rooting=True) 

3875 

3876 # ------------------------------------------------------------------------- 

3877 # Tidy branch supports 

3878 

3879 node_i = 0 

3880 

3881 branch_support_path = tree_path.replace(".treefile", ".branch_support.tsv").replace(".nwk", ".branch_support.tsv") 

3882 logging.info(f"Extracting branch supports to table: {branch_support_path}") 

3883 with open(branch_support_path, 'w') as outfile: 

3884 header = ["node", "original", "ufboot", "alrt", "significant"] 

3885 outfile.write('\t'.join(header) + '\n') 

3886 for n in tree.preorder_node_iter(): 

3887 # Sanity check quotations 

3888 if n.label: 

3889 n.label = n.label.replace("'", "") 

3890 if not n.is_leaf(): 

3891 if n.label: 

3892 original = n.label 

3893 else: 

3894 original = "NA" 

3895 support = n.label 

3896 significant = "" 

3897 ufboot = "NA" 

3898 alrt = "NA" 

3899 if support: 

3900 if "/" in support: 

3901 ufboot = float(support.split("/")[0]) 

3902 alrt = float(support.split("/")[1]) 

3903 if ufboot > 95 and alrt > 80: 

3904 significant = "*" 

3905 else: 

3906 ufboot = float(support.split("/")[0]) 

3907 alrt = "NA" 

3908 if ufboot > 95: 

3909 significant = "*" 

3910 n.label = f"NODE_{node_i}" 

3911 row = [n.label, str(original), str(ufboot), str(alrt), significant] 

3912 outfile.write('\t'.join(row) + '\n') 

3913 node_i += 1 

3914 

3915 # Tree V1: Branch support is replaced by node labels 

3916 node_labels_path = tree_path.replace(".treefile", ".labelled_nodes.nwk").replace(".nwk", ".labelled_nodes.nwk") 

3917 logging.info(f"Writing tree with node labels: {node_labels_path}") 

3918 tree.write(path=node_labels_path, schema="newick", suppress_rooting=True) 

3919 

3920 # Tree V2: No branch support or node labels (plain) 

3921 plain_path = node_labels_path.replace(".labelled_nodes.nwk", ".plain.nwk") 

3922 logging.info(f"Writing plain tree: {plain_path}") 

3923 for n in tree.preorder_node_iter(): 

3924 n.label = None 

3925 tree.write(path=plain_path, schema="newick", suppress_rooting=True) 

3926 return tree_path 
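
# --- Illustrative sketch (not part of the module) ----------------------------
# The branch-support tidying above assumes IQ-TREE node labels of the form
# "ufboot" or "ufboot/alrt"; this hypothetical helper mirrors that parsing and
# the 95/80 significance thresholds used for the table.
def _parse_support_sketch(label: str):
    fields = label.split("/")
    ufboot = float(fields[0])
    alrt = float(fields[1]) if len(fields) > 1 else None
    significant = ufboot > 95 and (alrt is None or alrt > 80)
    return ufboot, alrt, "*" if significant else ""
# _parse_support_sketch("98/85.2") -> (98.0, 85.2, "*")
# _parse_support_sketch("90")      -> (90.0, None, "")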

3927 

3928 

3929def get_delim(table:str): 

3930 if table.endswith(".tsv") or table.endswith(".Rtab"): 

3931 return "\t" 

3932 elif table.endswith(".csv"): 

3933 return "," 

3934 else: 

3935 msg = f"Unknown file extension of table: {table}" 

3936 logging.error(msg) 

3937 raise Exception(msg) 
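
# Usage sketch (illustrative): the file extension alone decides the delimiter.
# get_delim("variants.Rtab")   -> "\t"
# get_delim("samplesheet.csv") -> ","
# get_delim("notes.txt")       -> raises an Exception (unknown extension)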

3938 

3939def table_to_rtab( 

3940 table: str, 

3941 filter: str, 

3942 outdir: str = ".", 

3943 prefix: str = None, 

3944 args: str = None, 

3945 ): 

3946 """ 

3947 Convert a TSV/CSV table to an Rtab file based on regex filters. 

3948 

3949 Takes as input a TSV/CSV table to convert, and a TSV/CSV of regex filters. 

3950 The filter table should have a header with the fields: column, regex, name, where 

3951 'column' is the column to search, 'regex' is the regular expression pattern, and 

3952 'name' is how the output variant should be named in the Rtab. 

3953 

3954 An example `filter.tsv` might look like this: 

3955 

3956 column regex name extra 

3957 assembly .*sample2.* sample2 data that will be ignored 

3958 lineage .*2.* lineage_2 more data 

3959 

3960 Where the goal is to filter the assembly and lineage columns for particular values. 

3961 

3962 >>> table_to_rtab(table="samplesheet.csv", filter="filter.tsv") 

3963 """ 

3964 

3965 table_path, filter_path, output_dir = table, filter, outdir 

3966 prefix = f"{prefix}." if prefix != None else "" 

3967 

3968 check_output_dir(output_dir) 

3969 

3970 logging.info(f"Checking delimiters.") 

3971 table_delim = get_delim(table_path) 

3972 filter_delim = get_delim(filter_path) 

3973 

3974 logging.info(f"Reading filters: {filter_path}") 

3975 filters = OrderedDict() 

3976 with open(filter_path) as infile: 

3977 header = [h.strip() for h in infile.readline().split(filter_delim)] 

3978 col_i, regex_i, name_i = [header.index(c) for c in ["column", "regex", "name"]] 

3979 for line in infile: 

3980 row = [l.strip() for l in line.split(filter_delim)] 

3981 column,regex,name = row[col_i], row[regex_i], row[name_i] 

3982 

3983 logging.debug(f"name: {name}, column: {column}, regex: {regex}") 

3984 if name in filters: 

3985 msg = f"Filter name {name} is not unique." 

3986 logging.error(msg) 

3987 raise Exception(msg) 

3988 filters[name] = {k:v for k,v in zip(header, row)} 

3989 

3990 logging.info(f"Searching input table: {table_path}") 

3991 rtab_data = OrderedDict({name:{} for name in filters}) 

3992 all_samples = [] 

3993 with open(table_path) as infile: 

3994 header = [h.strip() for h in infile.readline().split(table_delim)] 

3995 # Locate filter columns in the file 

3996 filter_names = list(filters.keys()) 

3997 for name in filter_names: 

3998 column = filters[name]["column"] 

3999 if column in header: 

4000 filters[name]["i"] = header.index(column) 

4001 else: 

4002 logging.warning(f"Filter column {column} is not present in the table.") 

4003 del filters[name] 

4004 if len(filters) == 0: 

4005 msg = "No filters were found matching columns in the input table." 

4006 logging.error(msg) 

4007 raise Exception(msg) 

4008 # Parse the table lines 

4009 for line in infile: 

4010 row = [l.strip() for l in line.split(table_delim)] 

4011 sample = row[0] 

4012 if sample not in all_samples: 

4013 all_samples.append(sample) 

4014 for name,data in filters.items(): 

4015 regex = data["regex"] 

4016 text = row[data["i"]] 

4017 val = 1 if re.match(regex, text) else 0 

4018 rtab_data[name][sample] = val 

4019 

4020 rtab_path = os.path.join(output_dir, f"{prefix}output.Rtab") 

4021 extended_path = os.path.join(output_dir, f"{prefix}output.tsv") 

4022 output_delim = "\t" 

4023 logging.info(f"Writing Rtab output: {rtab_path}") 

4024 logging.info(f"Writing extended output: {extended_path}") 

4025 with open(rtab_path, 'w') as rtab_outfile: 

4026 header = ["Variant"] + [str(s) for s in all_samples] 

4027 rtab_outfile.write(output_delim.join(header) + "\n") 

4028 with open(extended_path, 'w') as extended_outfile: 

4029 first_filter = list(filters.keys())[0] 

4030 header = ["sample"] + [col for col in filters[first_filter].keys() if col != "i"] 

4031 extended_outfile.write(output_delim.join(header) + "\n") 

4032 

4033 for name,data in rtab_data.items(): 

4034 row = [name] + [str(data[s]) for s in all_samples] 

4035 line = output_delim.join([str(val) for val in row]) 

4036 rtab_outfile.write(line + "\n") 

4037 

4038 positive_samples = [s for s,o in data.items() if str(o) == "1"] 

4039 for sample in positive_samples: 

4040 row = [sample] + [v for k,v in filters[name].items() if k != "i"] 

4041 line = output_delim.join([str(val) for val in row]) 

4042 extended_outfile.write(line + "\n") 

4043 

4044def vcf_to_rtab( 

4045 vcf: str, 

4046 bed: str = None, 

4047 outdir: str = ".", 

4048 prefix: str = None, 

4049 args: str = None, 

4050 ): 

4051 """ 

4052 Convert a VCF file to an Rtab file. 

4053 

4054 >>> vcf_to_rtab(vcf="snps.vcf") 

4055 """ 

4056 

4057 vcf_path, bed_path, output_dir = vcf, bed, outdir 

4058 prefix = f"{prefix}." if prefix != None else "" 

4059 check_output_dir(output_dir) 

4060 

4061 bed = OrderedDict() 

4062 if bed_path != None: 

4063 logging.info(f"Reading BED: {bed_path}") 

4064 with open(bed_path) as infile: 

4065 lines = infile.readlines() 

4066 for line in tqdm(lines): 

4067 row = [l.strip() for l in line.split("\t")] 

4068 _chrom, start, end, name = row[0], row[1], row[2], row[3] 

4069 bed[name] = {"start": int(start) + 1, "end": int(end) + 1} 

4070 

4071 logging.info(f"Reading VCF: {vcf_path}") 

4072 

4073 all_samples = [] 

4074 

4075 rtab_path = os.path.join(output_dir, f"{prefix}output.Rtab") 

4076 

4077 with open(vcf_path) as infile: 

4078 with open(rtab_path, 'w') as outfile: 

4079 logging.info(f"Counting lines in VCF.") 

4080 lines = infile.readlines() 

4081 

4082 logging.info(f"Parsing variants.") 

4083 for line in tqdm(lines): 

4084 if line.startswith("##"): continue 

4085 # Write the Rtab header 

4086 if line.startswith("#"): 

4087 header = [l.strip().replace("#", "") for l in line.split("\t")] 

4088 # Samples start after 'FORMAT' 

4089 samples_i = header.index("FORMAT") + 1 

4090 all_samples = header[samples_i:] 

4091 line = "\t".join(["Variant"] + [str(s) for s in all_samples]) 

4092 outfile.write(line + "\n") 

4093 continue 

4094 

4095 row = [l.strip() for l in line.split("\t")] 

4096 data = {k:v for k,v in zip(header, row)} 

4097 chrom, pos = f"{data['CHROM']}", int(data['POS']) 

4098 ref = data['REF'] 

4099 alt = [nuc for nuc in data['ALT'].split(",") if nuc != "."] 

4100 

4101 # Skip multiallelic/missing 

4102 if len(alt) != 1: continue 

4103 

4104 alt = alt[0] 

4105 observations = [] 

4106 for sample in all_samples: 

4107 genotype = data[sample] 

4108 if genotype == "0/0": 

4109 value = "0" 

4110 elif genotype == "1/1": 

4111 value = "1" 

4112 else: 

4113 value = "." 

4114 

4115 observations.append(value) 

4116 

4117 for name,region in bed.items(): 

4118 start, end = region["start"], region["end"] 

4119 if pos >= start and pos < end: 

4120 chrom = name 

4121 pos = (pos - start) + 1 

4122 break 

4123 

4124 variant = f"{chrom}|snp:{ref}{pos}{alt}" 

4125 

4126 line = "\t".join([variant] + observations) 

4127 outfile.write(line + "\n") 
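
# Illustrative sketch (hypothetical data): a VCF row
#   CHROM  POS  REF  ALT  ...  FORMAT  sampleA  sampleB  sampleC
#   chr1   42   A    G    ...  GT      0/0      1/1      ./.
# becomes the Rtab row
#   chr1|snp:A42G    0    1    .
# If a BED region named "geneX" starts at 1-based position 40, the variant is
# instead written as geneX|snp:A3G, the position recomputed relative to the
# region start ((42 - 40) + 1 = 3).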

4128 

4129def binarize( 

4130 table: str, 

4131 column: str, 

4132 outdir: str = ".", 

4133 prefix: str = None, 

4134 output_delim: str = "\t", 

4135 column_prefix: str = None, 

4136 transpose: bool = False, 

4137 ): 

4138 """ 

4139 Convert a categorical column to multiple binary (0/1) columns. 

4140 

4141 Takes as input a table (tsv or csv) as well as a column to binarize. 

4142 

4143 >>> binarize(table="samplesheet.csv", column="lineage", output_delim=",") 

4144 >>> binarize(table="samplesheet.csv", column="resistant", output_delim="\t", transpose=True) 

4145 

4146 :param table: Path to table (TSV or CSV) 

4147 :param column: Name of column in table. 

4148 :param outdir: Path to the output directory. 

4149 """ 

4150 

4151 table_path, output_dir = table, outdir 

4152 column_prefix = column_prefix if column_prefix else "" 

4153 prefix = f"{prefix}." if prefix != None else "" 

4154 output_ext = "tsv" if output_delim == "\t" else "csv" if output_delim == "," else "txt" 

4155 

4156 # Check output directory 

4157 check_output_dir(output_dir) 

4158 input_delim = get_delim(table_path) 

4159 

4160 logging.info(f"Reading table: {table_path}") 

4161 

4162 all_samples = [] 

4163 values = {} 

4164 

4165 with open(table_path) as infile: 

4166 header = infile.readline().strip().split(input_delim) 

4167 sample_i = 0 

4168 column_i = header.index(column) 

4169 for line in infile: 

4170 line = [v.strip() for v in line.split(input_delim)] 

4171 sample, val = line[sample_i], line[column_i] 

4172 if sample not in all_samples: 

4173 all_samples.append(sample) 

4174 if val == "": continue 

4175 if val not in values: 

4176 values[val] = [] 

4177 values[val].append(sample) 

4178 

4179 # Sort values 

4180 values = OrderedDict({v:values[v] for v in sorted(values.keys())}) 

4181 

4182 binarize_path = os.path.join(output_dir, f"{prefix}binarize.{output_ext}") 

4183 logging.info(f"Writing binarized output: {binarize_path}") 

4184 with open(binarize_path, 'w') as outfile: 

4185 first_col = column if transpose else header[sample_i] 

4186 if transpose: 

4187 out_header = [first_col] + all_samples 

4188 outfile.write(output_delim.join(out_header) + "\n") 

4189 for val,samples in values.items(): 

4190 observations = ["1" if sample in samples else "0" for sample in all_samples] 

4191 row = [f"{column_prefix}{val}"] + observations 

4192 outfile.write(output_delim.join(row) + "\n") 

4193 else: 

4194 out_header = [first_col] + [f"{column_prefix}{k}" for k in values.keys()] 

4195 outfile.write(output_delim.join(out_header) + "\n") 

4196 for sample in all_samples: 

4197 row = [sample] 

4198 for val,val_samples in values.items(): 

4199 row.append("1") if sample in val_samples else row.append("0") 

4200 outfile.write(output_delim.join(row) + "\n") 
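
# Illustrative sketch (hypothetical data): binarizing column "lineage" with
# values s1=L1, s2=L2, s3=L2 and column_prefix="lineage_".
#
# Default orientation (one row per sample):
#   sample  lineage_L1  lineage_L2
#   s1      1           0
#   s2      0           1
#   s3      0           1
#
# transpose=True (one row per value, Rtab-like):
#   lineage     s1  s2  s3
#   lineage_L1  1   0   0
#   lineage_L2  0   1   1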

4201 

4202 

4203def qq_plot( 

4204 locus_effects:str, 

4205 outdir: str = ".", 

4206 prefix: str = None, 

4207 ): 

4208 """ 

4209 Modified version of pyseer's qq_plot function. 

4210 

4211 Source: https://github.com/mgalardini/pyseer/blob/master/scripts/qq_plot.py 

4212 """ 

4213 

4214 import matplotlib.pyplot as plt 

4215 import numpy as np 

4216 import statsmodels 

4217 import statsmodels.api as sm 

4218 import pandas as pd 

4219 

4220 locus_effects_path, output_dir = locus_effects, outdir 

4221 prefix = f"{prefix}." if prefix != None else "" 

4222 

4223 logging.info(f"Reading locus effects: {locus_effects_path}") 

4224 lrt_pvalues = [] 

4225 with open(locus_effects_path) as infile: 

4226 header = [l.strip() for l in infile.readline().split("\t")] 

4227 for line in infile: 

4228 row = [l.strip() for l in line.split("\t")] 

4229 data = {k:v for k,v in zip(header, row)} 

4230 lrt_pvalue = data["lrt-pvalue"] 

4231 if lrt_pvalue != "": 

4232 lrt_pvalues.append(float(lrt_pvalue)) 

4233 

4234 # Exclude p-value=0, which has no log 

4235 y = -np.log10([val for val in lrt_pvalues if val != 0]) 

4236 x = -np.log10(np.random.uniform(0, 1, len(y))) 

4237 

4238 # check for statsmodels version (issue #212) 

4239 old_stats = False 

4240 try: 

4241 vmajor, vminor = statsmodels.__version__.split('.')[:2] 

4242 if int(vmajor) == 0 and int(vminor) < 13: 

4243 old_stats = True 

4244 else: 

4245 old_stats = False 

4246 except: 

4247 msg = "Failed to identify the statsmodel version, QQ plot will not be created." 

4248 logging.warning(msg) 

4249 return None 

4250 

4251 if old_stats: 

4252 xx = y 

4253 yy = x 

4254 else: 

4255 xx = x 

4256 yy = y 

4257 

4258 # Plot 

4259 logging.info(f"Creating QQ plot.") 

4260 plt.figure(figsize=(4, 3.75)) 

4261 ax = plt.subplot(111) 

4262 

4263 fig = sm.qqplot_2samples(xx, 

4264 yy, 

4265 xlabel='Expected $-log_{10}(pvalue)$', 

4266 ylabel='Observed $-log_{10}(pvalue)$', 

4267 line='45', 

4268 ax=ax) 

4269 

4270 ax = fig.axes[0] 

4271 ax.lines[0].set_color('k') 

4272 ax.lines[0].set_alpha(0.3) 

4273 

4274 # Handle inf values if a p-value was 0 

4275 x_max = x.max() if not np.isinf(x.max()) else max([val for val in x if not np.isinf(val)]) 

4276 y_max = y.max() if not np.isinf(y.max()) else max([val for val in y if not np.isinf(val)]) 

4277 ax.set_xlim(-0.5, x_max + 0.5) 

4278 ax.set_ylim(-0.5, y_max + 0.5) 

4279 

4280 plt.tight_layout() 

4281 

4282 plot_path = os.path.join(output_dir, f"{prefix}qq_plot.png") 

4283 logging.info(f"Writing plot: {plot_path}") 

4284 plt.savefig(plot_path, dpi=150) 

4285 plt.close('all') 

4286 

4287 return plot_path 
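
# Minimal sketch (illustrative, not part of the module) of the expected-vs-observed
# construction above: under the null hypothesis p-values are Uniform(0, 1), so the
# expected axis is the -log10 transform of uniform draws, compared against the
# observed -log10 p-values.
# import numpy as np
# observed = -np.log10(np.array([0.04, 0.5, 1e-6]))
# expected = -np.log10(np.random.uniform(0, 1, len(observed)))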

4288 

4289 

4290def gwas( 

4291 variants: str, 

4292 table: str, 

4293 column: str, 

4294 clusters: str = None, 

4295 continuous: bool = False, 

4296 tree: str = None, 

4297 midpoint: bool = True, 

4298 outdir: str = ".", 

4299 prefix: str = None, 

4300 lineage_column: str = None, 

4301 exclude_missing: bool = False, 

4302 threads: int = 1, 

4303 args: str = GWAS_ARGS, 

4304 ): 

4305 """ 

4306 Run genome-wide association study (GWAS) tests with pyseer. 

4307 

4308 Takes as input the TSV file of summarized clusters, an Rtab file of variants, 

4309 a TSV/CSV table of phenotypes, and a column name representing the trait of interest. 

4310 Outputs tables of locus effects, and optionally lineage effects (bugwas) if specified. 

4311 

4312 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='resistant') 

4313 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='lineage', args='--no-distances') 

4314 >>> gwas(clusters='summarize.clusters.tsv', variants='combine.Rtab', table='samplesheet.csv', column='resistant', args='--lmm --min-af 0.05') 

4315 

4316 :param variants: Path to Rtab file of variants. 

4317 :param table: Path to TSV/CSV table of traits. 

4318 :param column: Column name of variable in table. 

4319 :param clusters: Path to TSV file of summarized clusters.  

4320 :param continuous: Treat column as a continuous variable. 

4321 :param tree: Path to newick tree. 

4322 :param midpoint: True if the tree should be rerooted at the midpoint. 

4323 :param outdir: Output directory. 

4324 :param prefix: Prefix for output files. 

4325 :param lineage_column: Name of lineage column in table (enables bugwas lineage effects). 

4326 :param exclude_missing: Exclude samples missing phenotype data. 

4327 :param threads: CPU threads for pyseer. 

4328 :param args: Str of additional arguments for pyseer. 

4329 """ 

4330 

4331 clusters_path, variants_path, table_path, tree_path, output_dir = clusters, variants, table, tree, outdir 

4332 

4333 # Check output directory 

4334 check_output_dir(output_dir) 

4335 

4336 # Read in the table that contains the phenotypes (optional lineage column) 

4337 logging.info(f"Reading table: {table_path}") 

4338 with open(table_path) as infile: 

4339 table = infile.read().replace(",", "\t") 

4340 table_split = table.split("\n") 

4341 table_header = table_split[0].split("\t") 

4342 table_rows = [[l.strip() for l in line.split("\t") ] for line in table_split[1:] if line.strip() != ""] 

4343 

4344 # If the user didn't provide any custom args, initialize as empty 

4345 args = args if args != None else "" 

4346 args += f" --cpu {threads}" 

4347 prefix = f"{prefix}." if prefix != None else "" 

4348 

4349 # Check for conflicts between input arguments 

4350 if "--no-distances" not in args: 

4351 if tree_path == None and "--lmm" in args: 

4352 msg = "You must supply a phylogeny if you have not requested --no-distances." 

4353 logging.error(msg) 

4354 raise Exception(msg) 

4355 else: 

4356 if lineage_column: 

4357 msg = "The param --no-distances cannot be used if a lineage column was specified." 

4358 logging.error(msg) 

4359 raise Exception(msg) 

4360 

4361 if "--wg enet" in args: 

4362 msg = "The whole genome elastic net model is not implemented yet." 

4363 logging.error(msg) 

4364 raise Exception(msg) 

4365 

4366 # If a lineage column was specified, extract it from the table into  

4367 # its own file for pyseer. 

4368 lineage_path = None 

4369 sample_i = 0 

4370 

4371 if lineage_column: 

4372 lineage_path = os.path.join(output_dir, f"{prefix}{column}.lineages.tsv") 

4373 logging.info(f"Extracting lineage column {lineage_column}: {lineage_path}") 

4374 with open(lineage_path, 'w') as lineage_file: 

4375 lineage_i = table_header.index(lineage_column) 

4376 for row in table_rows: 

4377 sample,lineage = row[sample_i], row[lineage_i] 

4378 lineage_file.write(f"{sample}\t{lineage}\n") 

4379 

4380 # ------------------------------------------------------------------------- 

4381 # Table / Phenotype 

4382 

4383 if column not in table_header: 

4384 msg = f"Column {column} is not present in the table: {table_path}" 

4385 logging.error(msg) 

4386 raise Exception(msg) 

4387 

4388 column_i = table_header.index(column) 

4389 

4390 # Exclude samples that are missing phenotype if requested 

4391 exclude_samples = set() 

4392 if exclude_missing: 

4393 logging.info("Excluding samples missing phenotype data.") 

4394 table_filter = [] 

4395 for row in table_rows: 

4396 sample = row[sample_i] 

4397 value = row[column_i] 

4398 if value == "NA" or value == "": 

4399 exclude_samples.add(sample) 

4400 continue 

4401 table_filter.append(row) 

4402 table_rows = table_filter 

4403 logging.info(f"Number of samples excluded: {len(exclude_samples)}") 

4404 

4405 # Convert the phenotypes table from csv to tsv for pyseer 

4406 if table_path.endswith(".csv"): 

4407 logging.info(f"Converting csv to tsv for pyseer: {table_path}") 

4408 file_name = os.path.basename(table_path) 

4409 table_tsv_path = os.path.join(output_dir, f"{prefix}{column}." + file_name.replace(".csv", ".tsv")) 

4410 with open(table_tsv_path, 'w') as outfile: 

4411 outfile.write("\t".join(table_header) + "\n") 

4412 for row in table_rows: 

4413 outfile.write("\t".join(row) + "\n") 

4414 

4415 # Check for categorical, binary, continuous trait 

4416 logging.info(f"Checking type of column: {column}") 

4417 

4418 observations = OrderedDict() 

4419 for row in table_rows: 

4420 sample, val = row[sample_i], row[column_i] 

4421 observations[sample] = val 

4422 

4423 all_samples = list(observations.keys()) 

4424 unique_observations = sorted(list(set(observations.values()))) 

4425 

4426 if len(unique_observations) <= 1: 

4427 msg = f"Column {column} must have at least two different values." 

4428 logging.error(msg) 

4429 raise Exception(msg) 

4430 

4431 # Now remove missing data 

4432 unique_observations = [o for o in unique_observations if o != ""] 

4433 if unique_observations == ["0", "1"] or unique_observations == ["0"] or unique_observations == ["1"]: 

4434 logging.info(f"Column {column} is a binary variable with values 0/1.") 

4435 unique_observations = [column] 

4436 elif continuous == True: 

4437 logging.info(f"Treating column {column} as continuous.") 

4438 try: 

4439 _check = [float(o) for o in unique_observations] 

4440 except: 

4441 msg = f"Failed to convert all values of {column} as numeric." 

4442 logging.error(msg) 

4443 raise Exception(msg) 

4444 unique_observations = [column] 

4445 args += " --continuous" 

4446 else: 

4447 logging.info(f"Binarizing categorical column {column} to multiple binary columns.") 

4448 binarize_prefix = f"{prefix}{column}" 

4449 binarize(table = table_tsv_path, column = column, column_prefix=f"{column}_", outdir=output_dir, prefix=binarize_prefix) 

4450 table_tsv_path = os.path.join(output_dir, f"{binarize_prefix}.binarize.tsv") 

4451 unique_observations = [f"{column}_{o}" for o in unique_observations] 

4452 

4453 # ------------------------------------------------------------------------- 

4454 # (optional) Cluster annotations 

4455 

4456 clusters = OrderedDict() 

4457 clusters_header = [] 

4458 

4459 if clusters_path != None: 

4460 logging.info(f"Reading summarized clusters: {clusters_path}") 

4461 with open(clusters_path) as infile: 

4462 clusters_header = [line.strip() for line in infile.readline().split("\t")] 

4463 

4464 lines = infile.readlines() 

4465 for line in tqdm(lines): 

4466 row = [l.strip() for l in line.split("\t")] 

4467 # Store the cluster data, exclude all sample columns, they will 

4468 # be provided by the variant input 

4469 data = { 

4470 k:v for k,v in zip(clusters_header, row) 

4471 if k not in exclude_samples and k not in all_samples 

4472 } 

4473 cluster = data["cluster"] 

4474 clusters[cluster] = data 

4475 

4476 # Exclude samples from clusters_header 

4477 clusters_header = [ 

4478 col for col in clusters_header 

4479 if col not in exclude_samples and col not in all_samples 

4480 ] 

4481 

4482 # ------------------------------------------------------------------------- 

4483 # Filter variants 

4484 

4485 # Exclude missing samples and invariants 

4486 # TBD: Might want to convert 'missing' chars all to '.' 

4487 

4488 logging.info(f"Reading variants: {variants_path}") 

4489 variants_filter_path = os.path.join(output_dir, f"{prefix}{column}.filter.Rtab") 

4490 logging.info(f"Writing filtered variants: {variants_filter_path}") 

4491 

4492 variants = {} 

4493 variants_header = [] 

4494 exclude_variants = set() 

4495 

4496 with open(variants_filter_path, "w") as outfile: 

4497 with open(variants_path) as infile: 

4498 header = infile.readline() 

4499 # Save the header before filtering, to match up values to samples 

4500 original_header = header.strip().split("\t")[1:] 

4501 # Filter samples in the variant files 

4502 variants_header = [s for s in original_header if s not in exclude_samples] 

4503 outfile.write("\t".join(["Variant"] + variants_header) + "\n") 

4504 

4505 # Filter invariants 

4506 lines = infile.readlines() 

4507 for line in tqdm(lines): 

4508 row = [r.strip() for r in line.split("\t")] 

4509 variant = row[0] 

4510 observations = [ 

4511 o 

4512 for s,o in zip(original_header, row[1:]) 

4513 if s not in exclude_samples 

4514 ] 

4515 unique_obs = set([o for o in observations]) 

4516 # Exclude simple invariants (less two possibilities) and invariants 

4517 # with missing, ex. [".", "0"], [".", "1"] 

4518 if (len(unique_obs) < 2) or (len(unique_obs) == 2 and ("." in unique_obs)): 

4519 exclude_variants.add(variant) 

4520 else: 

4521 variants[variant] = observations 

4522 filter_row = [variant] + observations 

4523 outfile.write("\t".join(filter_row) + "\n") 

4524 

4525 logging.info(f"Number of invariants excluded: {len(exclude_variants)}") 

4526 logging.info(f"Number of variants included: {len(variants)}") 

4527 variants_path = variants_filter_path 

4528 

4529 # ------------------------------------------------------------------------- 

4530 # Distance Matrices 

4531 

4532 patristic_path = os.path.join(output_dir, f"{prefix}{column}.patristic.tsv") 

4533 kinship_path = os.path.join(output_dir, f"{prefix}{column}.kinship.tsv") 

4534 

4535 if "--lmm" in args: 

4536 logging.info("Running pyseer with similarity matrix due to: --lmm") 

4537 args += f" --similarity {kinship_path}" 

4538 if lineage_path: 

4539 logging.info("Running pyseer with patristic distances due to: --lineage-column") 

4540 args += f" --distances {patristic_path}" 

4541 

4542 # The user might want to run with --no-distances if population structure (lineage) 

4543 # is the variable of interest. 

4544 if ("--distances" not in args and 

4545 "--similarity" not in args): 

4546 logging.info("Running pyseer with --no-distances") 

4547 #args += " --no-distances" 

4548 # Otherwise, we setup the distance matrices as needed 

4549 else: 

4550 logging.info(f"Creating distance matrices: {tree_path}") 

4551 # A modified version of pyseer's phylogeny_distance function 

4552 # Source: https://github.com/mgalardini/pyseer/blob/master/scripts/phylogeny_distance.py 

4553 tree = read_tree(tree_path, tree_format="newick") 

4554 

4555 if exclude_missing: 

4556 logging.info(f"Excluding missing samples from the tree: {tree_path}") 

4557 tree.prune_taxa_with_labels(labels=exclude_samples) 

4558 tree.prune_leaves_without_taxa() 

4559 

4560 tree_path = os.path.join(output_dir, f"{prefix}{column}.filter.nwk") 

4561 tree.write(path=tree_path, schema="newick", suppress_rooting=True) 

4562 

4563 if midpoint == True: 

4564 tree_path = os.path.join(output_dir, f"{prefix}{column}.midpoint.nwk") 

4565 logging.info(f"Rerooting tree at midpoint: {tree_path}") 

4566 tree = root_tree_midpoint(tree) 

4567 tree.write(path=tree_path, schema="newick", suppress_rooting=True) 

4568 

4569 # Re-read the tree after all of those adjustments 

4570 tree = read_tree(tree_path, tree_format="newick") 

4571 

4572 patristic = OrderedDict() 

4573 kinship = OrderedDict() 

4574 distance_matrix = tree.phylogenetic_distance_matrix() 

4575 

4576 for n1 in tree.taxon_namespace: 

4577 if n1.label not in all_samples: 

4578 logging.warning(f"Sample {n1.label} is present in the tree but not in the phenotypes, excluding from distance matrices.") 

4579 continue 

4580 kinship[n1.label] = kinship.get(n1.label, OrderedDict()) 

4581 patristic[n1.label] = patristic.get(n1.label, OrderedDict()) 

4582 for n2 in tree.taxon_namespace: 

4583 if n2.label not in all_samples: 

4584 continue 

4585 # Measure kinship/similarity as distance to root of MRCA 

4586 # TBD: Investigating impact of precision/significant digits 

4587 # https://github.com/mgalardini/pyseer/issues/286 

4588 if n2.label not in kinship[n1.label].keys(): 

4589 mrca = distance_matrix.mrca(n1, n2) 

4590 distance = mrca.distance_from_root() 

4591 kinship[n1.label][n2.label] = round(distance, 8) 

4592 # Measure patristic distance between the nodes 

4593 if n2.label not in patristic[n1.label].keys(): 

4594 distance = distance_matrix.patristic_distance(n1, n2) 

4595 patristic[n1.label][n2.label] = round(distance, 8) 

4596 

4597 if len(patristic) == 0 or len(kinship) == 0: 

4598 msg = "No samples are present in the distance matrices! Please check that your table sample labels match the tree." 

4599 logging.error(msg) 

4600 raise Exception(msg) 

4601 

4602 logging.info(f"Saving patristic distances to: {patristic_path}") 

4603 logging.info(f"Saving similarity kinship to: {kinship_path}") 

4604 

4605 with open(patristic_path, 'w') as patristic_file: 

4606 with open(kinship_path, 'w') as kinship_file: 

4607 distance_samples = [s for s in all_samples if s in kinship] 

4608 distance_header = "\t" + "\t".join(distance_samples) 

4609 patristic_file.write(distance_header + "\n") 

4610 kinship_file.write(distance_header + "\n") 

4611 # Row order based on all_samples for consistency 

4612 for s1 in distance_samples: 

4613 # Kinship/similarity 

4614 row = [s1] + [str(kinship[s1][s2]) for s2 in distance_samples] 

4615 kinship_file.write("\t".join(row) + "\n") 

4616 # Patristic 

4617 row = [s1] + [str(patristic[s1][s2]) for s2 in distance_samples] 

4618 patristic_file.write("\t".join(row) + "\n") 
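
# Toy sketch (illustrative; assumes dendropy, which the tree calls above use)
# contrasting the two matrices for a hypothetical tree ((A:1,B:3):2,C:3);
# import dendropy
# toy = dendropy.Tree.get(data="((A:1,B:3):2,C:3);", schema="newick")
# pdm = toy.phylogenetic_distance_matrix()
# taxa = {t.label: t for t in toy.taxon_namespace}
# # Patristic distance sums branch lengths along the path: A-B = 1 + 3 = 4
# pdm.patristic_distance(taxa["A"], taxa["B"])           # 4.0
# # Kinship/similarity is the MRCA's distance from the root (shared branch): 2
# pdm.mrca(taxa["A"], taxa["B"]).distance_from_root()    # 2.0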

4619 

4620 # ------------------------------------------------------------------------- 

4621 # Pyseer GWAS 

4622 

4623 for new_column in unique_observations: 

4624 logging.info(f"Running GWAS on column: {new_column}") 

4625 

4626 log_path = os.path.join(output_dir, f"{prefix}{new_column}.pyseer.log") 

4627 focal_path = os.path.join(output_dir, f"{prefix}{new_column}.focal.txt") 

4628 patterns_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.patterns.txt") 

4629 locus_effects_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.tsv") 

4630 locus_significant_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.significant.tsv") 

4631 locus_filter_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.significant.filter.tsv") 

4632 lineage_effects_path = os.path.join(output_dir, f"{prefix}{new_column}.lineage_effects.tsv") 

4633 lineage_significant_path = os.path.join(output_dir, f"{prefix}{new_column}.lineage_effects.significant.tsv") 

4634 bonferroni_path = os.path.join(output_dir, f"{prefix}{new_column}.locus_effects.bonferroni.txt") 

4635 cmd = textwrap.dedent( 

4636 f"""\ 

4637 pyseer {args} 

4638 --pres {variants_path} 

4639 --phenotypes {table_tsv_path} 

4640 --phenotype-column {new_column} 

4641 --output-patterns {patterns_path} 

4642 """) 

4643 if lineage_path: 

4644 cmd += f" --lineage --lineage-clusters {lineage_path} --lineage-file {lineage_effects_path}" 

4645 cmd = cmd.replace("\n", " ") 

4646 cmd_pretty = cmd + f" 1> {locus_effects_path} 2> {log_path}" 

4647 logging.info(f"pyseer command: {cmd_pretty}") 

4648 run_cmd(cmd, output=locus_effects_path, err=log_path, quiet=True, display_cmd=False) 

4649 

4650 # Record which samples are 'positive' for this trait 

4651 positive_samples = [] 

4652 negative_samples = [] 

4653 missing_samples = [] 

4654 with open(table_tsv_path) as infile: 

4655 header = [l.strip() for l in infile.readline().split("\t")] 

4656 for line in infile: 

4657 row = [l.strip() for l in line.split("\t")] 

4658 data = {k:v for k,v in zip(header, row)} 

4659 sample, observation = data["sample"], data[new_column] 

4660 if observation != ".": 

4661 if float(observation) != 0: 

4662 positive_samples.append(sample) 

4663 else: 

4664 negative_samples.append(sample) 

4665 else: 

4666 missing_samples.append(sample) 

4667 

4668 logging.info(f"{len(positive_samples)}/{len(all_samples)} samples are positive (>0) for {new_column}.") 

4669 logging.info(f"{len(negative_samples)}/{len(all_samples)} samples are negative (=0) for {new_column}.") 

4670 logging.info(f"{len(missing_samples)}/{len(all_samples)} samples are missing (.) for {new_column}.") 

4671 

4672 with open(focal_path, 'w') as outfile: 

4673 outfile.write("\n".join(positive_samples)) 

4674 # ------------------------------------------------------------------------- 

4675 # Significant threshold 

4676 

4677 logging.info(f"Determining significance threshold.") 

4678 

4679 with open(patterns_path) as infile: 

4680 patterns = list(set([l.strip() for l in infile.readlines()])) 

4681 num_patterns = len(patterns) 

4682 logging.info(f"Number of patterns: {num_patterns}") 

4683 bonferroni = 0.05 / float(num_patterns) if num_patterns > 0 else 0 

4684 logging.info(f"Bonferroni threshold (0.05 / num_patterns): {bonferroni}") 

4685 with open(bonferroni_path, 'w') as outfile: 

4686 outfile.write(str(bonferroni) + "\n") 
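
# Worked example (illustrative): with 1,000 unique variant patterns the
# Bonferroni-corrected threshold is 0.05 / 1000 = 5e-05; a locus effect is later
# reported as significant only if its lrt-pvalue falls below this value.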

4687 

4688 # ------------------------------------------------------------------------- 

4689 # Extract variants into dict 

4690 

4691 logging.info(f"Extracting variants.") 

4692 

4693 gwas_variants = OrderedDict() 

4694 # Keep track of the smallest non-zero pvalue for the log10 transformation 

4695 # of 0 values 

4696 min_pvalue = 1 

4697 all_pvalues = set() 

4698 

4699 with open(locus_effects_path) as infile: 

4700 locus_effects_header = infile.readline().strip().split("\t") 

4701 for line in infile: 

4702 row = [l.strip() for l in line.split("\t")] 

4703 data = {k:v for k,v in zip(locus_effects_header, row)} 

4704 data["notes"] = ",".join(sorted(data["notes"].split(","))) 

4705 # Convert pvalue 

4706 pvalue = data["lrt-pvalue"] 

4707 if pvalue != "": 

4708 pvalue = float(pvalue) 

4709 # Keep track of the smallest non-zero pvalue 

4710 if pvalue != 0: 

4711 min_pvalue = min(pvalue, min_pvalue) 

4712 all_pvalues.add(pvalue) 

4713 data["-log10(p)"] = "" 

4714 data["bonferroni"] = bonferroni 

4715 variant = data["variant"] 

4716 gwas_variants[variant] = data 

4717 locus_effects_header += ["-log10(p)", "bonferroni"] 

4718 

4719 logging.info(f"Minimum pvalue observed: {min(all_pvalues)}") 

4720 logging.info(f"Minimum pvalue observed (non-zero): {min_pvalue}") 

4721 

4722 # ------------------------------------------------------------------------- 

4723 # -log10 transformation 

4724 logging.info("Applying -log10 transformation.") 

4725 

4726 # For p-values that are 0, we will use this tiny value instead 

4727 # Remember, the min_pvalue is the minimum pvalue that is NOT 0 

4728 small_val = float("1.00E-100") 

4729 if small_val < min_pvalue: 

4730 if 0 in all_pvalues: 

4731 logging.warning(f"Using small float for pvalue=0 transformations: {small_val}") 

4732 else: 

4733 logging.warning(f"Using pvalue min for pvalue=0 transformations: {min_pvalue}") 

4734 small_val = min_pvalue 

4735 

4736 for variant, v_data in gwas_variants.items(): 

4737 pvalue = v_data["lrt-pvalue"] 

4738 if pvalue == "": continue 

4739 pvalue = float(pvalue) 

4740 if pvalue == 0: 

4741 logging.warning(f"{variant} has a pvalue of 0. log10 transformation will use: {small_val}") 

4742 pvalue = small_val 

4743 gwas_variants[variant]["-log10(p)"] = -math.log10(pvalue) 
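
# Worked example (illustrative): p = 1e-8 gives -log10(p) = 8.0. A p-value reported
# as exactly 0 has no logarithm, so it is replaced by small_val (1e-100, or the
# smallest non-zero p-value observed if that is even smaller), capping the
# transformed value at, for example, -log10(1e-100) = 100.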

4744 

4745 # ------------------------------------------------------------------------- 

4746 # Annotate and sort the output files 

4747 

4748 # Extract the cluster information for each variant 

4749 logging.info(f"Extracting cluster identifiers.") 

4750 locus_effects = OrderedDict() 

4751 

4752 for variant,v_data in gwas_variants.items(): 

4753 variant_split = variant.split("|") 

4754 cluster = variant_split[0] 

4755 row = list(v_data.values()) 

4756 if cluster not in locus_effects: 

4757 locus_effects[cluster] = [] 

4758 locus_effects[cluster].append(row) 

4759 

4760 logging.info(f"Adding cluster annotations and variant observations.") 

4761 locus_effects_rows = [] 

4762 match_header = ["match", "mismatch", "score"] 

4763 for cluster in tqdm(locus_effects): 

4764 cluster_row = [clusters[cluster][col] for col in clusters_header] 

4765 for effect_row in locus_effects[cluster]: 

4766 effect_data = {k:v for k,v in zip(locus_effects_header, effect_row)} 

4767 try: 

4768 beta = float(effect_data["beta"]) 

4769 except ValueError: 

4770 beta = "" 

4771 variant = effect_row[0] 

4772 variant_row = variants[variant] 

4773 variant_data = {k:v for k,v in zip(variants_header, variant_row)} 

4774 # A simple summary statistic of how well the variant observations match the phenotype 

4775 denom = len(all_samples) 

4776 # Presence of variant should 'increase' or 'decrease' phenotype 

4777 if beta == "" or beta > 0: 

4778 v1, v2 = ["1", "0"] 

4779 else: 

4780 v1, v2 = ["0", "1"] 

4781 match_num = len([ 

4782 k for k,v in variant_data.items() 

4783 if (v == v1 and k in positive_samples) or (v == v2 and k in negative_samples) 

4784 ]) 

4785 mismatch_num = len([ 

4786 k for k,v in variant_data.items() 

4787 if (v == v1 and k not in positive_samples) or (v == v2 and k not in negative_samples) 

4788 ]) 

4789 

4790 match_value = f"{match_num}/{denom}|{round(match_num / denom, 2)}" 

4791 mismatch_value = f"{mismatch_num}/{denom}|{round(mismatch_num / denom, 2)}" 

4792 match_score = (match_num/denom) - (mismatch_num/denom) 

4793 # TBD: What if we are operating on lrt-filtering-fail p-values? 

4794 # We have no way to know the beta direction, and which way the score should go... 

4795 if beta == "": 

4796 match_score = abs(match_score) 

4797 match_row = [match_value, mismatch_value, round(match_score, 2)] 

4798 locus_effects_rows += [effect_row + match_row + cluster_row + variant_row] 
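
# Worked example (illustrative): with 50 samples and beta > 0, a variant present
# (1) in 35 phenotype-positive samples and absent (0) in 5 phenotype-negative
# samples gives match = 40/50 and mismatch = 10/50, so
# match_score = 0.8 - 0.2 = 0.6; perfect agreement scores 1.0 and perfect
# disagreement scores -1.0.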

4799 

4800 logging.info(f"Sorting locus effects by lrt-pvalue: {locus_effects_path}") 

4801 locus_pvalues = OrderedDict() 

4802 locus_null = OrderedDict() 

4803 for i,row in enumerate(locus_effects_rows): 

4804 data = {k:v for k,v in zip(locus_effects_header, row)} 

4805 variant, lrt_pvalue = data["variant"], data["lrt-pvalue"] 

4806 if lrt_pvalue == "": 

4807 if lrt_pvalue not in locus_null: 

4808 locus_null[lrt_pvalue] = [] 

4809 locus_null[lrt_pvalue].append((variant, i)) 

4810 else: 

4811 lrt_pvalue = float(lrt_pvalue) 

4812 if lrt_pvalue not in locus_pvalues: 

4813 locus_pvalues[lrt_pvalue] = [] 

4814 locus_pvalues[lrt_pvalue].append((variant, i)) 

4815 locus_pvalues = sorted(locus_pvalues.items(), key=lambda kv: (kv[0], kv[1])) 

4816 locus_null = sorted(locus_null.items(), key=lambda kv: (kv[0], kv[1])) 

4817 locus_order = [] 

4818 for data in locus_pvalues + locus_null: 

4819 for variant, i in data[1]: 

4820 locus_order.append(i) 

4821 locus_effects_rows = [locus_effects_rows[i] for i in locus_order] 

4822 

4823 # Save the results 

4824 with open(locus_effects_path, 'w') as outfile: 

4825 header = "\t".join(locus_effects_header + match_header + clusters_header + variants_header) 

4826 outfile.write(header + "\n") 

4827 for row in locus_effects_rows: 

4828 line = "\t".join([str(r) for r in row]) 

4829 outfile.write(line + "\n") 

4830 

4831 logging.info(f"Sorting patterns: {patterns_path}") 

4832 with open(patterns_path) as infile: 

4833 lines = sorted([l for l in infile.readlines()]) 

4834 with open(patterns_path, 'w') as outfile: 

4835 outfile.write("".join(lines)) 

4836 

4837 if lineage_path: 

4838 logging.info(f"Sorting lineage effects: {lineage_effects_path}") 

4839 with open(lineage_effects_path) as infile: 

4840 header = infile.readline() 

4841 lines = sorted([l for l in infile.readlines()]) 

4842 with open(lineage_effects_path, 'w') as outfile: 

4843 outfile.write(header + "".join(lines)) 

4844 

4845 # ------------------------------------------------------------------------- 

4846 # Locus Effects 

4847 

4848 logging.info("Identifying significant locus effects.") 

4849 with open(locus_effects_path) as infile: 

4850 header_line = infile.readline() 

4851 header = header_line.strip().split("\t") 

4852 with open(locus_significant_path, 'w') as significant_outfile: 

4853 significant_outfile.write(header_line) 

4854 with open(locus_filter_path, 'w') as filter_outfile: 

4855 filter_outfile.write(header_line) 

4856 for line in infile: 

4857 row = [l.strip() for l in line.split("\t")] 

4858 data = {k:v for k,v in zip(header, row)} 

4859 lrt_pvalue, notes = data["lrt-pvalue"], data["notes"] 

4860 if lrt_pvalue != "" and float(lrt_pvalue) < bonferroni: 

4861 significant_outfile.write(line) 

4862 if "bad-chisq" not in notes and "high-bse" not in notes: 

4863 filter_outfile.write(line) 

4864 

4865 # ------------------------------------------------------------------------- 

4866 # Lineage effects 

4867 

4868 if lineage_path: 

4869 logging.info("Identifying significant lineage effects.") 

4870 with open(lineage_effects_path) as infile: 

4871 header_line = infile.readline() 

4872 header = header_line.strip().split("\t") 

4873 pvalue_i = header.index("p-value") 

4874 with open(lineage_significant_path, 'w') as significant_outfile: 

4875 significant_outfile.write(header_line) 

4876 for line in infile: 

4877 row = [l.strip() for l in line.split("\t")] 

4878 pvalue = float(row[pvalue_i]) 

4879 if pvalue < bonferroni: 

4880 significant_outfile.write(line) 

4881 

4882 # ------------------------------------------------------------------------- 

4883 # QQ Plot 

4884 

4885 if len(locus_effects_rows) == 0: 

4886 logging.info("Skipping QQ Plot as no variants were observed.") 

4887 else: 

4888 logging.info("Creating QQ Plot.") 

4889 _plot_path = qq_plot( 

4890 locus_effects = locus_effects_path, 

4891 outdir = output_dir, 

4892 prefix = f"{prefix}{new_column}" 

4893 ) 

4894 

4895 

4896def text_to_path(text, tmp_path="tmp.svg", family="Roboto", size=16, clean=True): 

4897 """ 

4898 Convert a string of text to SVG paths. 

4899 

4900 :param text: String of text. 

4901 :param tmp_path: File path to temporary file to render svg. 

4902 :param family: Font family. 

4903 :param size: Font size. 

4904 :param clean: True if temporary file should be deleted upon completion. 

4905 """ 

4906 

4907 import cairo 

4908 from xml.dom import minidom 

4909 from svgpathtools import parse_path 

4910 

4911 

4912 # Approximate and appropriate height and width for the tmp canvas 

4913 # This just needs to be at least big enough to hold it 

4914 w = size * len(text) 

4915 h = size * 2 

4916 # Render text as individual glyphs 

4917 with cairo.SVGSurface(tmp_path, w, h) as surface: 

4918 context = cairo.Context(surface) 

4919 context.move_to(0, h/2) 

4920 context.set_font_size(size) 

4921 context.select_font_face(family) 

4922 context.show_text(text) 

4923 

4924 # Parse text data and positions from svg DOM 

4925 doc = minidom.parse(tmp_path) 

4926 

4927 # Keep track of overall text bbox 

4928 t_xmin, t_xmax, t_ymin, t_ymax = None, None, None, None 

4929 glyphs = [] 

4930 

4931 # Absolute positions of each glyph are under <use> 

4932 # ex. <use xlink:href="#glyph-0-1" x="11.96875" y="16"/> 

4933 for u in doc.getElementsByTagName('use'): 

4934 glyph_id = u.getAttribute("xlink:href").replace("#", "") 

4935 # Get absolute position of the glyph 

4936 x = float(u.getAttribute("x")) 

4937 y = float(u.getAttribute("y")) 

4938 # Get relative path data of the glyph 

4939 for g in doc.getElementsByTagName('g'): 

4940 g_id = g.getAttribute("id") 

4941 if g_id != glyph_id: continue 

4942 for path in g.getElementsByTagName('path'): 

4943 d = path.getAttribute("d") 

4944 p = parse_path(d) 

4945 # Get glyph relative bbox 

4946 xmin, xmax, ymin, ymax = [c for c in p.bbox()] 

4947 # Convert bbox to absolute coordinates 

4948 xmin, xmax, ymin, ymax = xmin + x, xmax + x, ymin + y, ymax + y 

4949 # Update the text bbox 

4950 if t_xmin == None: 

4951 t_xmin, t_xmax, t_ymin, t_ymax = xmin, xmax, ymin, ymax 

4952 else: 

4953 t_xmin = min(t_xmin, xmin) 

4954 t_xmax = max(t_xmax, xmax) 

4955 t_ymin = min(t_ymin, ymin) 

4956 t_ymax = max(t_ymax, ymax) 

4957 bbox = xmin, xmax, ymin, ymax 

4958 

4959 w, h = xmax - xmin, ymax - ymin 

4960 glyph = {"id": glyph_id, "x": x, "y": y, "w": w, "h": h, "d": d, "bbox": bbox} 

4961 glyphs.append(glyph) 

4962 

4963 if not t_xmax: 

4964 logging.error(f"Failed to render text: {text}") 

4965 raise Exception(f"Failed to render text: {text}") 

4966 # Calculate the final dimensions of the entire text 

4967 w, h = t_xmax - t_xmin, t_ymax - t_ymin 

4968 bbox = t_xmin, t_xmax, t_ymin, t_ymax 

4969 

4970 result = {"text": text, "glyphs": glyphs, "w": w, "h": h, "bbox": bbox } 

4971 

4972 if clean: 

4973 os.remove(tmp_path) 

4974 

4975 return result 
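
# Usage sketch (illustrative; requires pycairo and svgpathtools, plus the requested
# font being available on the system):
# label = text_to_path("p-value", family="Roboto", size=16)
# label["w"], label["h"]      # overall rendered width and height
# label["glyphs"][0]["d"]     # SVG path data of the first glyph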

4976 

4977 

4978def linear_scale(value, value_min, value_max, target_min=0.0, target_max=1.0): 

4979 if value_min == value_max: 

4980 return value 

4981 else: 

4982 return (value - value_min) * (target_max - target_min) / (value_max - value_min) + target_min 
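
# Worked example (illustrative): rescaling from the range [0, 10] onto [0, 100].
# linear_scale(2.5, 0, 10, 0, 100) -> 25.0
# linear_scale(10, 0, 10, 0, 100)  -> 100.0
# When value_min == value_max the value is returned unchanged to avoid dividing by zero.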

4983 

4984def manhattan( 

4985 gwas: str, 

4986 bed: str, 

4987 outdir: str = ".", 

4988 prefix: str = None, 

4989 font_size: int = 16, 

4990 font_family: str ="Roboto", 

4991 margin: int = 20, 

4992 width: int = 1000, 

4993 height: int = 500, 

4994 png_scale: float = 2.0, 

4995 prop_x_axis: bool = False, 

4996 ymax: float = None, 

4997 max_blocks: int = 20, 

4998 syntenies: list = ["all"], 

4999 clusters: list = ["all"], 

5000 variant_types: list = ["all"], 

5001 args: str = None, 

5002 ): 

5003 """ 

5004 Plot the distribution of variant p-values across the genome. 

5005 

5006 Takes as input a table of locus effects from the gwas subcommand and a BED 

5007 file such as the one produced by the align subcommand. Outputs a 

5008 Manhattan plot in SVG and PNG format. 

5009 

5010 >>> manhattan(gwas="locus_effects.tsv", bed="pangenome.bed") 

5011 >>> manhattan(gwas="locus_effects.tsv", bed="pangenome.bed", syntenies=["chromosome"], clusters=["pbpX"], variant_types=["snp", "presence_absence"]) 

5012  

5013 :param gwas: Path to tsv table of locus effects from gwas subcommand. 

5014 :param bed: Path to BED file of coordinates. 

5015 :param outdir: Output directory. 

5016 :param prefix: Output prefix. 

5017 :param width: Width in pixels of plot. 

5018 :param height: Height in pixels of plot. 

5019 :param margin: Size in pixels of plot margins. 

5020 :param font_size: Font size. 

5021 :param font_family: Font family. 

5022 :param png_scale: Float that adjusts png scale relative to svg. 

5023 :param prop_x_axis: Scale x-axis based on the length of the synteny block. 

5024 :param ymax: Maximum value for the y-axis -log10(p). 

5025 :param syntenies: Names of synteny blocks to plot. 

5026 :param clusters: Names of clusters to plot. 

5027 :param variant_types: Names of variant types to plot. 

5028 """ 

5029 

5030 import numpy as np 

5031 

5032 logging.info(f"Importing cairo.") 

5033 import cairosvg 

5034 

5035 gwas_path, bed_path, output_dir = gwas, bed, outdir 

5036 prefix = f"{prefix}." if prefix != None else "" 

5037 check_output_dir(outdir) 

5038 

5039 # Type conversions fallback 

5040 width, height, margin = int(width), int(height), int(margin) 

5041 png_scale = float(png_scale) 

5042 if ymax != None: ymax = float(ymax) 

5043 

5044 # handle space values if given 

5045 syntenies = syntenies if type(syntenies) != str else syntenies.split(" ") 

5046 syntenies = [str(s) for s in syntenies] 

5047 

5048 clusters = clusters if type(clusters) != str else clusters.split(" ") 

5049 clusters = [str(c) for c in clusters] 

5050 

5051 variant_types = variant_types if type(variant_types) != str else variant_types.split(" ") 

5052 variant_types = [str(v) for v in variant_types] 

5053 

5054 plot_data = OrderedDict() 

5055 

5056 # ------------------------------------------------------------------------- 

5057 # Parse BED (Synteny Blocks) 

5058 

5059 pangenome_length = 0 

5060 

5061 logging.info(f"Reading BED coordinates: {bed_path}") 

5062 with open(bed_path) as infile: 

5063 for line in infile: 

5064 row = [l.strip() for l in line.split("\t")] 

5065 # Adjust 0-based BED coordinates back to 1-based 

5066 c_start, c_end, name = int(row[1]) + 1, int(row[2]), row[3] 

5067 pangenome_length = c_end 

5068 c_length = (c_end - c_start) + 1 

5069 # ex. cluster=geneA;synteny=443 

5070 info = {n.split("=")[0]:n.split("=")[1] for n in name.split(";")} 

5071 cluster, synteny = info["cluster"], info["synteny"] 

5072 

5073 if synteny not in plot_data: 

5074 plot_data[synteny] = { 

5075 "pangenome_pos": [c_start, c_end], 

5076 "length": c_length, 

5077 "variants": OrderedDict(), 

5078 "clusters": OrderedDict(), 

5079 "ticks": OrderedDict(), 

5080 "prop": 0.0, 

5081 } 

5082 

5083 # update the end position and length of the current synteny block 

5084 plot_data[synteny]["pangenome_pos"][1] = c_end 

5085 plot_data[synteny]["length"] = ( 

5086 c_end - plot_data[synteny]["pangenome_pos"][0] 

5087 ) + 1 

5088 

5089 s_start = plot_data[synteny]["pangenome_pos"][0] 

5090 c_data = { 

5091 "pangenome_pos": [c_start, c_end], 

5092 "synteny_pos": [1 + c_start - s_start, 1+ c_end - s_start], 

5093 "cluster_pos": [1, c_length], 

5094 "length": c_length, 

5095 } 

5096 plot_data[synteny]["clusters"][cluster] = c_data 
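
# Worked example (illustrative, hypothetical BED rows) of the three coordinate
# systems built above, for two clusters in the same synteny block "s1":
#   pangenome  0     1500  cluster=geneA;synteny=s1
#   pangenome  1500  2400  cluster=geneB;synteny=s1
# geneA: pangenome_pos [1, 1500],    synteny_pos [1, 1500],    cluster_pos [1, 1500]
# geneB: pangenome_pos [1501, 2400], synteny_pos [1501, 2400], cluster_pos [1, 900]
# Synteny block "s1" then spans pangenome positions 1..2400 (length 2400).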

5097 

5098 # ------------------------------------------------------------------------- 

5099 # Parse GWAS (Variants) 

5100 

5101 logging.info(f"Reading GWAS table: {gwas_path}") 

5102 

5103 alpha = None 

5104 log10_pvalues = set() 

5105 

5106 with open(gwas_path) as infile: 

5107 

5108 header = [l.strip() 

5109 for l in infile.readline().split("\t")] 

5110 if "cluster" not in header: 

5111 msg = "GWAS table must contain cluster annotations for manhattan plot." 

5112 logging.error(msg) 

5113 raise Exception(msg) 

5114 

5115 lines = infile.readlines() 

5116 

5117 if len(lines) == 0: 

5118 msg = "GWAS table contains no variants." 

5119 logging.warning(msg) 

5120 return 0 

5121 

5122 for line in tqdm(lines): 

5123 row = [l.strip() for l in line.split("\t")] 

5124 data = {k:v for k,v in zip(header, row)} 

5125 variant, cluster, synteny = data["variant"], data["cluster"], data["synteny"] 

5126 

5127 alpha = float(data["bonferroni"]) if alpha == None else alpha 

5128 

5129 if data["lrt-pvalue"] == "": continue 

5130 pvalue = float(data["lrt-pvalue"]) 

5131 log10_pvalue = float(data["-log10(p)"]) 

5132 log10_pvalues.add(log10_pvalue) 

5133 

5134 # try to find matching synteny block by name 

5135 c_data = None 

5136 if synteny in plot_data: 

5137 c_data = plot_data[synteny]["clusters"][cluster] 

5138 # Otherwise, find by cluster 

5139 else: 

5140 for s,s_data in plot_data.items(): 

5141 if cluster in s_data["clusters"]: 

5142 synteny = s 

5143 c_data = s_data["clusters"][cluster] 

5144 

5145 if c_data == None: 

5146 msg = f"Synteny block {synteny} for {cluster} is not present in the BED file." 

5147 logging.error(msg) 

5148 raise Exception(msg) 

5149 

5150 s_data = plot_data[synteny] 

5151 s_start, s_end = s_data["pangenome_pos"] 

5152 c_start, c_end = c_data["synteny_pos"] 

5153 

5154 # 3 different coordinates systems 

5155 cluster_coords = [] 

5156 synteny_coords = [] 

5157 pangenome_coords = [] 

5158 

5159 if "|snp" in variant: 

5160 variant_type = "snp" 

5161 snp = variant.split("|")[1].split(":")[1] 

5162 pos = int("".join([c for c in snp if c.isnumeric()])) 

5163 cluster_pos = [pos, pos] 

5164 cluster_coords.append(cluster_pos) 

5165 elif "|presence_absence" in variant: 

5166 variant_type = "presence_absence" 

5167 cluster_pos = c_data["cluster_pos"] 

5168 cluster_coords.append(cluster_pos) 

5169 elif "|structural" in variant: 

5170 variant_type = "structural" 

5171 variant_split = variant.split("|")[1].split(":")[1] 

5172 for interval in variant_split.split("_"): 

5173 cluster_pos = [int(v) for v in interval.split("-")] 

5174 cluster_coords.append(cluster_pos) 

5175 else: 

5176 logging.warning(f"Skipping unknown variant type: {variant}") 

5177 continue 

5178 

5179 # Apply filter 

5180 if "all" not in variant_types and variant_type not in variant_types: continue 

5181 if "all" not in syntenies and synteny not in syntenies: continue 

5182 if "all" not in clusters and cluster not in clusters: continue 

5183 

5184 # Convert cluster coords to synteny and pangenome coords 

5185 for pos in cluster_coords: 

5186 synteny_pos = [c_start + pos[0], c_start + pos[1]] 

5187 synteny_coords.append(synteny_pos) 

5188 pangenome_pos = [s_start + synteny_pos[0], s_start + synteny_pos[1]] 

5189 pangenome_coords.append(pangenome_pos) 

5190 

5191 plot_data[synteny]["variants"][variant] = { 

5192 "variant": variant, 

5193 "synteny": synteny, 

5194 "synteny_pos": int(data["synteny_pos"]), 

5195 "cluster": cluster, 

5196 "pvalue": pvalue, 

5197 "-log10(p)": log10_pvalue, 

5198 "variant_h2": float(data["variant_h2"]), 

5199 "type": variant_type, 

5200 "cluster_coords": cluster_coords, 

5201 "synteny_coords": synteny_coords, 

5202 "pangenome_coords": pangenome_coords, 

5203 } 

5204 

5205 alpha_log10 = -math.log10(alpha) 

5206 if ymax == None: 

5207 max_log10 = math.ceil(max(log10_pvalues) if max(log10_pvalues) > alpha_log10 else alpha_log10) 

5208 else: 

5209 max_log10 = ymax 

5210 

5211 min_log10 = 0 

5212 log10_vals = np.arange(min_log10, max_log10 + 1, max_log10 / 4) 

5213 

5214 # ------------------------------------------------------------------------- 

5215 # Exclude synteny blocks with no variants 

5216 plot_data = OrderedDict({k:v for k,v in plot_data.items() if len(v["variants"]) > 0}) 

5217 

5218 # Adjust the total length based on just the synteny blocks we've observed 

5219 total_length = sum( 

5220 [sum([c["length"] for c in s["clusters"].values()]) 

5221 for s in plot_data.values()] 

5222 ) 

5223 

5224 observed_clusters = set() 

5225 for s in plot_data.values(): 

5226 for v in s["variants"].values(): 

5227 observed_clusters.add(v["cluster"]) 

5228 for s,s_data in plot_data.items(): 

5229 plot_data[s]["clusters"] = OrderedDict({ 

5230 k:v for k,v in plot_data[s]["clusters"].items() 

5231 if k in observed_clusters 

5232 }) 

5233 

5234 # If there's just a single cluster, adjust the total length 

5235 if len(observed_clusters) == 1: 

5236 only_synteny = list(plot_data.keys())[0] 

5237 only_cluster = list(observed_clusters)[0] 

5238 total_length = plot_data[only_synteny]["clusters"][only_cluster]["length"] 

5239 elif "all" not in clusters: 

5240 total_length = sum( 

5241 [sum([c["length"] for c in s["clusters"].values()]) 

5242 for s in plot_data.values()] 

5243 ) 

5244 

5245 if len(plot_data) == 0: 

5246 msg = "No variants remain after filtering." 

5247 logging.warning(msg) 

5248 return None 

5249 

5250 for s,s_data in plot_data.items(): 

5251 s_length = sum([c["length"] for c in s_data["clusters"].values()]) 

5252 s_prop = s_length / total_length 

5253 plot_data[s]["prop"] = s_prop 

5254 

5255 # ------------------------------------------------------------------------- 

5256 # Phandango data 

5257 

5258 phandango_path = os.path.join(output_dir, f"{prefix}phandango.plot") 

5259 logging.info(f"Creating phandango input: {phandango_path}") 

5260 

5261 with open(phandango_path, 'w') as outfile: 

5262 header = ["#CHR", "SNP", "BP", "minLOG10(P)", "log10(p)", "r^2"] 

5263 outfile.write("\t".join(header) + "\n") 

5264 

5265 for synteny,s_data in plot_data.items(): 

5266 pangenome_pos = s_data["pangenome_pos"] 

5267 for variant, v_data in s_data["variants"].items(): 

5268 log10_pvalue = v_data["-log10(p)"] 

5269 variant_h2 = v_data["variant_h2"] 

5270 cluster = v_data["cluster"] 

5271 c_data = s_data["clusters"][cluster] 

5272 

5273 # If a snp, we take the actual position 

5274 # for presence/absence or structural, we take the start 

5275 if "|snp" in variant: 

5276 pos = v_data["pangenome_coords"][0][0] 

5277 else: 

5278 pos = c_data["pangenome_pos"][0] 

5279 row = ["pangenome", variant, pos, log10_pvalue, log10_pvalue, variant_h2] 

5280 line = "\t".join([str(v) for v in row]) 

5281 outfile.write(line + "\n") 

5282 

5283 # ------------------------------------------------------------------------- 

5284 # Y-Axis Label Dimensions 

5285 

5286 logging.info(f"Calculating y-axis label dimensions.") 

5287 

5288 # Main y-axis label 

5289 y_axis_text = "-log10(p)" 

5290 y_axis_label = text_to_path(text=y_axis_text, size=font_size, family=font_family) 

5291 y_axis_label_h = y_axis_label["h"] 

5292 y_axis_label_x = margin + y_axis_label_h 

5293 y_axis_label["rx"] = y_axis_label_x 

5294 # To set the y position, we'll need to wait until after x axis labels are done 

5295 

5296 y_tick_hmax = 0 

5297 y_tick_wmax = 0 

5298 y_tick_labels = OrderedDict() 

5299 

5300 for val in log10_vals: 

5301 label = text_to_path(text=str(val), size=font_size * 0.75, family=font_family) 

5302 w, h = label["w"], label["h"] 

5303 y_tick_hmax, y_tick_wmax = max(y_tick_hmax, h), max(y_tick_wmax, w) 

5304 y_tick_labels[val] = label 

5305 
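# Note on text_to_path (defined elsewhere in this module): from its usage here
# and below it appears to return a dict with at least "w", "h" and "glyphs"
# keys, where each glyph carries its own "x", "y", "h" and SVG path data "d".
# This description is inferred from usage, not copied from its definition.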

5306 # We now know some initial dimensions 

5307 tick_len = y_tick_hmax / 2 

5308 tick_pad = tick_len 

5309 y_axis_x = y_axis_label_x + y_axis_label_h + y_tick_wmax + tick_pad + tick_len 

5310 y_axis_y1 = margin 

5311 # To set the y2 position, we'll need to wait until after x axis labels are done 

5312 

5313 # ------------------------------------------------------------------------- 

5314 # X-Axis Label Dimensions 

5315 

5316 logging.info(f"Calculating x-axis label dimensions.") 

5317 

5318 # X-axis main label 

5319 if len(plot_data) > max_blocks: 

5320 x_axis_text = "Pangenome" 

5321 elif len(plot_data) == 1: 

5322 x_axis_text = str(list(plot_data.keys())[0]) 

5323 if len(observed_clusters) == 1: 

5324 x_axis_text = str(list(observed_clusters)[0]) 

5325 else: 

5326 x_axis_text = "Synteny Block" 

5327 

5328 x_axis_label = text_to_path(text=x_axis_text, size=font_size, family=font_family) 

5329 x_axis_label_w, x_axis_label_h = x_axis_label["w"], x_axis_label["h"] 

5330 x_axis_label_y = height - margin 

5331 

5332 x_axis_x1 = y_axis_x 

5333 x_axis_x2 = width - margin 

5334 x_axis_w = x_axis_x2 - x_axis_x1 

5335 x_axis_step = x_axis_w / len(plot_data) 

5336 

5337 x_axis_label_x = x_axis_x1 + (x_axis_w / 2) - (x_axis_label_w / 2) 

5338 

5339 x_tick_hmax = 0 

5340 x_tick_wmax = 0 

5341 

5342 start_coord, end_coord = 0, total_length 

5343 

5344 

5345 if len(plot_data) > max_blocks or len(plot_data) == 1: 

5346 

5347 if len(plot_data) > max_blocks: 

5348 logging.info(f"Plot data exceeds max_blocks of {max_blocks}, enabling pangenome coordinates.") 

5349 start_coord, end_coord = 0, pangenome_length 

5350 elif len(plot_data) == 1: 

5351 only_synteny = list(plot_data.keys())[0] 

5352 logging.info(f"Single synteny block detected, enabling synteny coordinates.") 

5353 # For one cluster, we might have variable start positions 

5354 if len(observed_clusters) == 1: 

5355 only_cluster = list(observed_clusters)[0] 

5356 c_data = plot_data[only_synteny]["clusters"][only_cluster] 

5357 start_coord, end_coord = c_data["cluster_pos"] 

5358 # Round start down and end up to the nearest 10 

5359 start_coord = start_coord - (start_coord % 10) 

5360 end_coord = end_coord + ((10 - end_coord % 10) % 10) 

5361 

5362 label = text_to_path(text="0123456789", size=font_size * 0.75, family=font_family) 

5363 h = label["h"] 

5364 # Put 1/2 h spacing in between 

5365 max_labels = x_axis_w / (2 * h) 

5366 # Prefer either 10 or 4 ticks 

5367 if max_labels >= 10: 

5368 max_labels = 10 

5369 elif max_labels >= 4: 

5370 max_labels = 4 

5371 

5372 step_size = 1 

5373 num_labels = max_labels + 1 

5374 # Candidate step sizes: 5, 10, 50, 100, 500, 1000... 

5375 while num_labels > max_labels: 

5376 if str(step_size).startswith("5"): 

5377 step_size = step_size * 2 

5378 else: 

5379 step_size = step_size * 5 

5380 num_labels = total_length / step_size 

5381 

5382 x_tick_vals = list(range(start_coord, end_coord + 1, step_size)) 
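# Illustrative trace of the step-size loop above (lengths are hypothetical):
# with max_labels = 10 and total_length = 4200, the candidate step sizes are
# 5, 10, 50, 100 and 500, stopping at 500 because 4200 / 500 = 8.4 <= 10.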

5383 

5384 x_axis_step = x_axis_w / num_labels 

5385 for val in x_tick_vals: 

5386 text = str(val) 

5387 label = text_to_path(text=text, size=font_size * 0.75, family=font_family) 

5388 w, h = label["w"], label["h"] 

5389 x_tick_hmax, x_tick_wmax = max(x_tick_hmax, h), max(x_tick_wmax, w) 

5390 plot_data[synteny]["ticks"][text] = {"label": label} 

5391 else: 

5392 for synteny in plot_data: 

5393 label = text_to_path(text=str(synteny), size=font_size * 0.75, family=font_family) 

5394 w, h = label["w"], label["h"] 

5395 x_tick_hmax, x_tick_wmax = max(x_tick_hmax, h), max(x_tick_wmax, w) 

5396 plot_data[synteny]["ticks"][synteny] = {"label": label} 

5397 

5398 

5399 # We now have enough information to finalize the axes coordinates 

5400 x_axis_y = x_axis_label_y - (x_axis_label_h * 2) - x_tick_wmax - tick_pad - tick_len 

5401 y_axis_y2 = x_axis_y 

5402 y_axis_h = y_axis_y2 - y_axis_y1 

5403 y_axis_label_y = y_axis_y1 + (y_axis_h / 2) 

5404 y_axis_label["ry"] = y_axis_label_y 

5405 

5406 # ------------------------------------------------------------------------- 

5407 logging.info(f"Positioning labels.") 

5408 

5409 for i_g,g in enumerate(x_axis_label["glyphs"]): 

5410 x_axis_label["glyphs"][i_g]["x"] = x_axis_label_x + g["x"] 

5411 x_axis_label["glyphs"][i_g]["y"] = x_axis_label_y 

5412 

5413 for i_g,g in enumerate(y_axis_label["glyphs"]): 

5414 y_axis_label["glyphs"][i_g]["x"] = y_axis_label_x + g["x"] 

5415 y_axis_label["glyphs"][i_g]["y"] = y_axis_label_y 

5416 

5417 # ------------------------------------------------------------------------- 

5418 # X-Axis 

5419 

5420 logging.info("Creating x-axis.") 

5421 

5422 prev_x = x_axis_x1 

5423 for synteny in plot_data: 

5424 s_data = plot_data[synteny] 

5425 

5426 if len(plot_data) == 1: 

5427 plot_data[synteny]["x1"] = x_axis_x1 

5428 plot_data[synteny]["x2"] = x_axis_x2 

5429 

5430 for tick,t_data in plot_data[synteny]["ticks"].items(): 

5431 

5432 if prop_x_axis == True: 

5433 xw = x_axis_w * s_data["prop"] 

5434 else: 

5435 xw = x_axis_step 

5436 

5437 if len(plot_data) > 1 and len(plot_data) <= max_blocks: 

5438 plot_data[synteny]["x1"] = prev_x 

5439 plot_data[synteny]["x2"] = prev_x + xw 

5440 

5441 # Pixel coordinates of this synteny block 

5442 if len(plot_data) == 1 or len(plot_data) > max_blocks: 

5443 x = linear_scale(int(tick), start_coord, end_coord, x_axis_x1, x_axis_x2) 

5444 else: 

5445 x = prev_x + (xw / 2) 

5446 

5447 # Tick Label 

5448 label = t_data["label"] 

5449 y = x_axis_y + tick_len + label["w"] + (tick_pad) 

5450 t_data["label"]["rx"] = x 

5451 t_data["label"]["ry"] = y 

5452 

5453 # Tick Line 

5454 t_data["line"] = {"x": x, "y1": x_axis_y, "y2": x_axis_y + tick_len} 

5455 # Center label on the very last glyph 

5456 center_glyph = label["glyphs"][-1] 

5457 label_y = y + (center_glyph["h"] / 2) 

5458 for i_g,g in enumerate(label["glyphs"]): 

5459 t_data["label"]["glyphs"][i_g]["x"] = x + g["x"] 

5460 t_data["label"]["glyphs"][i_g]["y"] = label_y 

5461 

5462 plot_data[synteny]["ticks"][tick] = t_data 

5463 

5464 prev_x += xw 

5465 

5466 # ------------------------------------------------------------------------- 

5467 # Y-Axis 

5468 

5469 logging.info("Creating y-axis.") 

5470 

5471 y_axis = OrderedDict({k:{"label": v, "line": None} for k,v in y_tick_labels.items()}) 

5472 

5473 for val,label in y_tick_labels.items(): 

5474 y = linear_scale(val, max_log10, min_log10, y_axis_y1, y_axis_y2) 

5475 x1, x2 = y_axis_x - tick_len, y_axis_x 

5476 

5477 # Tick Line 

5478 y_axis[val]["line"] = { "x1": x1, "x2": x2, "y": y } 

5479 

5480 # Tick Label 

5481 y_axis[val]["label"] = label 

5482 center_glyph = label["glyphs"][-1] 

5483 label_y = y + (center_glyph["h"] / 2) 

5484 label_x = x1 - tick_pad - label["w"] 

5485 

5486 for i_g,g in enumerate(label["glyphs"]): 

5487 label["glyphs"][i_g]["x"] = label_x + g["x"] 

5488 label["glyphs"][i_g]["y"] = label_y 

5489 

5490 y_axis[val]["label"] = label 

5491 
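# Sketch of the assumed behaviour of linear_scale (defined elsewhere in this
# module); inferred from its call sites, not copied from its definition:
#   def linear_scale(value, old_min, old_max, new_min, new_max):
#       return new_min + (value - old_min) * (new_max - new_min) / (old_max - old_min)
# Here it projects -log10(p) values onto pixel coordinates between the axis
# endpoints, with the ranges reversed so larger values sit higher on the plot.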

5492 # ------------------------------------------------------------------------- 

5493 # Render 

5494 

5495 radius = 2 

5496 

5497 svg_path = os.path.join(output_dir, f"{prefix}plot.svg") 

5498 logging.info(f"Rendering output svg ({width}x{height}): {svg_path}") 

5499 with open(svg_path, 'w') as outfile: 

5500 header = textwrap.dedent( 

5501 f"""\ 

5502 <svg  

5503 version="1.1" 

5504 xmlns="http://www.w3.org/2000/svg" 

5505 xmlns:xlink="http://www.w3.org/1999/xlink" 

5506 preserveAspectRatio="xMidYMid meet" 

5507 width="{width}" 

5508 height="{height}" 

5509 viewBox="0 0 {width} {height}"> 

5510 """) 

5511 outfile.write(header) 

5512 

5513 indent = " " 

5514 

5515 # White canvas background 

5516 outfile.write(f"{indent * 1}<g id='Background'>\n") 

5517 background = f"{indent * 2}<rect width='{width}' height='{height}' x='0' y='0' style='fill:white;stroke-width:1;stroke:white'/>" 

5518 outfile.write(background + "\n") 

5519 outfile.write(f"{indent * 1}</g>\n") 

5520 

5521 # Axes 

5522 outfile.write(f"{indent * 1}<g id='axis'>\n") 

5523 

5524 # --------------------------------------------------------------------- 

5525 # X Axis 

5526 

5527 outfile.write(f"{indent * 2}<g id='x'>\n") 

5528 

5529 # Background panels 

5530 if len(plot_data) <= max_blocks: 

5531 outfile.write(f"{indent * 3}<g id='Background Panels'>\n") 

5532 for i_s,synteny in enumerate(plot_data): 

5533 if i_s % 2 == 0: continue 

5534 x1, x2 = plot_data[synteny]["x1"], plot_data[synteny]["x2"] 

5535 y = margin 

5536 w = x2 - x1 

5537 h = y_axis_h 

5538 rect = f"{indent * 2}<rect width='{w}' height='{h}' x='{x1}' y='{y}' style='fill:grey;fill-opacity:0.10'/>" 

5539 outfile.write(rect + "\n") 

5540 outfile.write(f"{indent * 3}</g>\n") 

5541 

5542 # X axis line 

5543 outfile.write(f"{indent * 3}<g id='line'>\n") 

5544 line = f"{indent * 4}<path d='M {x_axis_x1} {x_axis_y} H {x_axis_x2}' style='stroke:black;stroke-width:2;fill:none'/>" 

5545 outfile.write(line + "\n") 

5546 outfile.write(f"{indent * 3}</g>\n") 

5547 

5548 # X axis label 

5549 outfile.write(f"{indent * 3}<g id='label'>\n") 

5550 for g in x_axis_label["glyphs"]: 

5551 line = f"{indent * 4}<path transform='translate({g['x']},{g['y']})' d='{g['d']}'/>" 

5552 outfile.write(line + "\n") 

5553 outfile.write(f"{indent * 3}</g>\n") 

5554 

5555 # X axis ticks 

5556 outfile.write(f"{indent * 3}<g id='ticks'>\n") 

5557 for synteny in plot_data: 

5558 for tick,t_data in plot_data[synteny]["ticks"].items(): 

5559 # Tick line 

5560 t = t_data["line"] 

5561 line = f"{indent * 4}<path d='M {t['x']} {t['y1']} V {t['y2']}' style='stroke:black;stroke-width:1;fill:none'/>" 

5562 outfile.write(line + "\n") 

5563 # Tick label 

5564 label = t_data["label"] 

5565 rx, ry = label["rx"], label["ry"] 

5566 for g in label["glyphs"]: 

5567 line = f"{indent * 4}<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}'/>" 

5568 outfile.write(line + "\n") 

5569 

5570 outfile.write(f"{indent * 3}</g>\n") 

5571 

5572 # Close x-axis 

5573 outfile.write(f"{indent * 2}</g>\n") 

5574 

5575 # --------------------------------------------------------------------- 

5576 # Y Axis 

5577 

5578 outfile.write(f"{indent * 2}<g id='y-axis'>\n") 

5579 

5580 # Y-axis line 

5581 outfile.write(f"{indent * 3}<g id='line'>\n") 

5582 line = f"{indent * 4}<path d='M {y_axis_x} {y_axis_y1} V {y_axis_y2}' style='stroke:black;stroke-width:2;fill:none'/>" 

5583 outfile.write(line + "\n") 

5584 outfile.write(f"{indent * 3}</g>\n") 

5585 

5586 # Y axis label 

5587 outfile.write(f"{indent * 3}<g id='label'>\n") 

5588 for g in y_axis_label["glyphs"]: 

5589 rx, ry = y_axis_label["rx"], y_axis_label["ry"] 

5590 line = f"{indent * 4}<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}'/>" 

5591 outfile.write(line + "\n") 

5592 outfile.write(f"{indent * 3}</g>\n") 

5593 

5594 outfile.write(f"{indent * 3}<g id='ticks'>\n") 

5595 # Ticks 

5596 for t,t_data in y_axis.items(): 

5597 # Tick line 

5598 t = t_data["line"] 

5599 line = f"{indent * 3}<path d='M {t['x1']} {t['y']} H {t['x2']}' style='stroke:black;stroke-width:1;fill:none'/>" 

5600 outfile.write(line + "\n") 

5601 # Tick Label 

5602 label = t_data["label"] 

5603 for g in label["glyphs"]: 

5604 line = f"{indent * 4}<path transform='translate({g['x']},{g['y']})' d='{g['d']}'/>" 

5605 outfile.write(line + "\n") 

5606 outfile.write(f"{indent * 3}</g>\n") 

5607 

5608 # Close y-axis 

5609 outfile.write(f"{indent * 2}</g>\n") 

5610 

5611 # Close all Axes 

5612 outfile.write(f"{indent * 1}</g>\n") 

5613 

5614 

5615 # -------------------------------------------------------------  

5616 # Data 

5617 outfile.write(f"{indent * 1}<g id='variants'>\n") 

5618 

5619 for i_s, synteny in enumerate(plot_data): 

5620 s_data = plot_data[synteny] 

5621 if len(plot_data) <= max_blocks: 

5622 s_x1, s_x2 = s_data["x1"], s_data["x2"] 

5623 else: 

5624 s_x1, s_x2 = x_axis_x1, x_axis_x2 

5625 s_len = s_data["length"] 

5626 

5627 variants = plot_data[synteny]["variants"].values() 

5628 variants_order = [v for v in variants if v["type"] == "presence_absence"] 

5629 variants_order += [v for v in variants if v["type"] == "structural"] 

5630 variants_order += [v for v in variants if v["type"] == "snp"] 

5631 opacity = 0.50 

5632 

5633 for v_data in variants_order: 

5634 # If we're plotting according to pangenome coordinates 

5635 # we use the panGWAS green color 

5636 if len(plot_data) > max_blocks: 

5637 f = "#356920" 

5638 # Otherwise we alternate between the classic blue and orange 

5639 else: 

5640 f = "#242bbd" if i_s % 2 == 0 else "#e37610" 

5641 vy = linear_scale(v_data["-log10(p)"], min_log10, max_log10, y_axis_y2, y_axis_y1) 

5642 variant = v_data["variant"].replace("'", "") 

5643 

5644 # In many browsers, the title will become hover text! 

5645 hover_text = [] 

5646 for k,v in v_data.items(): 

5647 if "coords" not in k: 

5648 text = f"{k}: {v}" 

5649 else: 

5650 coords = ", ".join(["-".join([str(xs) for xs in x]) for x in v]) 

5651 text = f"{k}: {coords}" 

5652 hover_text.append(text) 

5653 hover_text = "\n".join(hover_text) 

5654 title = f"{indent * 4}<title>{hover_text}</title>" 

5655 

5656 outfile.write(f"{indent * 2}<g id='{variant}'>\n") 

5657 

5658 if len(observed_clusters) == 1: 

5659 s_start = start_coord 

5660 s_end = end_coord 

5661 else: 

5662 s_start = 1 

5663 s_end = s_len 

5664 

5665 # Decide on the coordinate system 

5666 if len(plot_data) > max_blocks: 

5667 coordinate_system = "pangenome" 

5668 s_start, s_end = 0, pangenome_length 

5669 else: 

5670 coordinate_system = "synteny" 

5671 

5672 pos = v_data[f"{coordinate_system}_coords"][0][0] 

5673 vx = linear_scale(pos, s_start, s_end, s_x1, s_x2) 

5674 

5675 # Circle 

5676 if v_data["type"] == "snp": 

5677 circle = f"{indent * 3}<circle cx='{vx}' cy='{vy}' r='{radius}' style='fill:{f};fill-opacity:{opacity}'>" 

5678 outfile.write(circle + "\n") 

5679 outfile.write(title + "\n") 

5680 outfile.write(f"{indent * 3}</circle>" + "\n") 

5681 

5682 # Square 

5683 elif v_data["type"] == "presence_absence": 

5684 pos = v_data[f"{coordinate_system}_coords"][0][0] 

5685 vx = linear_scale(pos, s_start, s_end, s_x1, s_x2) 

5686 rect = f"{indent * 3}<rect width='{radius*2}' height='{radius*2}' x='{vx}' y='{vy - radius}' style='fill:{f};fill-opacity:{opacity}'>" 

5687 outfile.write(rect + "\n") 

5688 outfile.write(title + "\n") 

5689 outfile.write(f"{indent * 3}</rect>" + "\n") 

5690 

5691 # Diamond 

5692 elif v_data["type"] == "structural": 

5693 rx, ry = vx + radius, vy + radius 

5694 rect = f"{indent * 3}<rect transform='rotate(-45, {rx}, {ry})' width='{radius*2}' height='{radius*2}' x='{vx}' y='{vy - radius}' style='fill:{f};fill-opacity:{opacity}'>" 

5695 outfile.write(rect + "\n") 

5696 outfile.write(title + "\n") 

5697 outfile.write(f"{indent * 3}</rect>" + "\n") 

5698 

5699 outfile.write(f"{indent * 2}</g>\n") 

5700 outfile.write(f"{indent * 1}</g>\n") 

5701 

5702 # p-value threshold 

5703 outfile.write(f"{indent * 1}<g id='alpha'>\n") 

5704 x1, x2 = x_axis_x1, x_axis_x2 

5705 y = linear_scale(alpha_log10, min_log10, max_log10, y_axis_y2, y_axis_y1) 

5706 line = f"{indent * 2}<path d='M {x1} {y} H {x2}' style='stroke:grey;stroke-width:1;fill:none' stroke-dasharray='4 4'/>" 

5707 outfile.write(line + "\n") 

5708 outfile.write(f"{indent * 1}</g>\n") 

5709 

5710 outfile.write("</svg>" + "\n") 

5711 

5712 png_path = os.path.join(output_dir, f"{prefix}plot.png") 

5713 logging.info(f"Rendering output png ({width}x{height}): {png_path}") 

5714 cairosvg.svg2png(url=svg_path, write_to=png_path, output_width=width, output_height=height, scale=png_scale) 

5715 

5716 return svg_path 

5717 

5718 

5719def heatmap( 

5720 tree: str=None, 

5721 tree_format: str="newick", 

5722 gwas: str=None, 

5723 rtab: str=None, 

5724 outdir: str = ".", 

5725 prefix: str=None, 

5726 focal: str=None, 

5727 min_score: float = None, 

5728 tree_width=100, 

5729 margin=20, 

5730 root_branch=10, 

5731 tip_pad=10, 

5732 font_size=16, 

5733 font_family="Roboto", 

5734 png_scale=2.0, 

5735 heatmap_scale=1.5, 

5736 palette={"presence_absence": "#140d91", "structural": "#7a0505", "snp": "#0e6b07", "unknown": "#636362"}, 

5737 args: str = None, 

5738 ): 

5739 """ 

5740 Plot a tree and/or a heatmap of variants. 

5741 

5742 Takes as input a newick tree and/or a table of variants. The table can be either 

5743 an Rtab file, or the locus effects TSV output from the gwas subcommand. 

5744 If both a tree and a table are provided, the tree will determine the sample order  

5745 and arrangement. If just a table is provided, sample order will follow the  

5746 order of the sample columns. A TXT of focal sample IDs can also be supplied 

5747 with one sample ID per line. Outputs a plot in SVG and PNG format. 

5748 

5749 >>> heatmap(tree="tree.rooted.nwk") 

5750 >>> heatmap(rtab="combine.Rtab") 

5751 >>> heatmap(gwas="resistant.locus_effects.significant.tsv") 

5752 >>> heatmap(tree="tree.rooted.nwk", rtab="combine.Rtab", focal="focal.txt") 

5753 >>> heatmap(tree="tree.rooted.nwk", gwas="resistant.locus_effects.significant.tsv") 

5754 >>> heatmap(tree="tree.rooted.nwk", tree_width=500) 

5755 

5756 :param tree: Path to newick tree. 

5757 :param gwas: Path to tsv table of locus effects from gwas subcommand. 

5758 :param rtab: Path to Rtab file of variants. 

5759 :param prefix: Output prefix. 

5760 :param focal: Path to text file of focal sample IDs to highlight. 

5761 :param min_score: Filter GWAS variants for a minimum score. 

5762 :param tree_width: Width in pixels of tree. 

5763 :param margin: Size in pixels of plot margins. 

5764 :param root_branch: Width in pixels of root branch. 

5765 :param tip_pad: Pad in pixels between tip labels and branches. 

5766 :param font_size: Font size. 

5767 :param font_family: Font family. 

5768 :param png_scale: Float that adjusts png scale relative to svg. 

5769 :param heatmap_scale: Float that adjusts heatmap box scale relative to text. 

5770 :param palette: Dict of variant types to colors. 

5771 """ 

5772 

5773 tree_path, gwas_path, focal_path, rtab_path, output_dir = tree, gwas, focal, rtab, outdir 

5774 prefix = f"{prefix}." if prefix != None else "" 

5775 

5776 if not tree_path and not gwas_path and not rtab_path: 

5777 msg = "Either a tree (--tree), a GWAS table (--gwas) or an Rtab (--rtab) must be supplied." 

5778 logging.error(msg) 

5779 raise Exception(msg) 

5780 

5781 elif gwas_path and rtab_path: 

5782 msg = "A GWAS table (--gwas) is mutually exclusive with an Rtab (--rtab) file." 

5783 logging.error(msg) 

5784 raise Exception(msg) 

5785 

5786 # Check output directory 

5787 check_output_dir(output_dir) 

5788 

5789 logging.info(f"Importing cairo.") 

5790 import cairosvg 

5791 

5792 # Todo: If we want to color by beta/p-value 

5793 # logging.info(f"Importing matplotlib color palettes.") 

5794 # import matplotlib as mpl 

5795 # import matplotlib.pyplot as plt 

5796 # from matplotlib.colors import rgb2hex 

5797 

5798 # ------------------------------------------------------------------------- 

5799 # Parse Focal Samples 

5800 

5801 focal = [] 

5802 if focal_path: 

5803 logging.info(f"Parsing focal samples: {focal_path}") 

5804 with open(focal_path) as infile: 

5805 for line in infile: 

5806 sample = line.strip() 

5807 if sample != "" and sample not in focal: 

5808 focal.append(sample) 

5809 

5810 # ------------------------------------------------------------------------- 

5811 # Parse Input Tree 

5812 

5813 if tree_path: 

5814 logging.info(f"Parsing {tree_format} tree: {tree_path}") 

5815 tree = read_tree(tree=tree_path, tree_format=tree_format) 

5816 tree.is_rooted = True 

5817 tree.ladderize() 

5818 else: 

5819 root_branch = 0 

5820 tree_width = 0 

5821 

5822 # ------------------------------------------------------------------------- 

5823 # Node labels 

5824 

5825 node_labels = {} 

5826 node_labels_wmax = 0 

5827 node_labels_wmax_text = None 

5828 node_labels_hmax = 0 

5829 node_labels_hmax_text = None 

5830 all_samples = [] 

5831 tree_tips = [] 

5832 tree_nodes = [] 

5833 

5834 # Option 1: Tips from the tree 

5835 if tree_path: 

5836 # Cleanup up labels, dendropy sometimes inserts quotations 

5837 logging.info(f"Parsing tree labels.") 

5838 node_i = 0 

5839 for node in list(tree.preorder_node_iter()): 

5840 # Tip labels 

5841 if node.taxon: 

5842 node.label = str(node.taxon) 

5843 logging.debug(f"Tree Tip: {node.label}") 

5844 # Internal node labels 

5845 else: 

5846 if not node.label or "NODE_" not in node.label: 

5847 if node_i == 0: 

5848 logging.warning(f"Using 'NODE_<i>' nomenclature for internal node labels.") 

5849 node.label = f"NODE_{node_i}" 

5850 node_i += 1 

5851 node.label = node.label.replace("\"", "").replace("'", "") 

5852 

5853 # Keep a list of tip labels 

5854 tree_tips = [n.label for n in tree.leaf_node_iter()] 

5855 tree_nodes = [n.label for n in tree.preorder_node_iter()] 

5856 all_samples = copy.deepcopy(tree_tips) 

5857 

5858 # Option 2: Tips from GWAS columns 

5859 if gwas_path: 

5860 logging.info(f"Parsing gwas labels: {gwas_path}") 

5861 with open(gwas_path) as infile: 

5862 header = infile.readline().strip().split("\t") 

5863 # Sample IDs come after the dbxref_alt column (if we used summarized clusters) 

5864 # otherwise they come after the "score" column 

5865 if "dbxref_alt" in header: 

5866 samples = header[header.index("dbxref_alt")+1:] 

5867 else: 

5868 samples = header[header.index("score")+1:] 

5869 for sample in samples: 

5870 if sample not in all_samples: 

5871 all_samples.append(sample) 

5872 

5873 # Option 3: Tips from Rtab columns 

5874 if rtab_path: 

5875 logging.info(f"Parsing rtab labels: {rtab_path}") 

5876 with open(rtab_path) as infile: 

5877 header = infile.readline().strip().split("\t") 

5878 # Sample IDs are all columns after the first 

5879 samples = header[1:] 

5880 for sample in samples: 

5881 if sample not in all_samples: 

5882 all_samples.append(sample) 

5883 

5884 # Figure out the largest sample label, we'll need to know 

5885 # this to position elements around it. 

5886 logging.info(f"Calculating sample label dimensions.") 

5887 

5888 for node in tqdm(all_samples): 

5889 label = text_to_path(text=node, size=font_size, family=font_family) 

5890 node_labels[node] = label 

5891 w, h = label["w"], label["h"] 

5892 if w > node_labels_wmax: 

5893 node_labels_wmax = w 

5894 node_labels_wmax_text = node 

5895 if h > node_labels_hmax: 

5896 node_labels_hmax = h 

5897 node_labels_hmax_text = node 

5898 

5899 node_labels_wmax = math.ceil(node_labels_wmax) 

5900 node_labels_hmax = math.ceil(node_labels_hmax) 

5901 

5902 logging.info(f"The widest sample label is {math.ceil(node_labels_wmax)}px: {node_labels_wmax_text}") 

5903 logging.info(f"The tallest sample label is {math.ceil(node_labels_hmax)}px: {node_labels_hmax_text}") 

5904 

5905 # ------------------------------------------------------------------------- 

5906 # Variant Labels 

5907 

5908 # Figure out the largest variant label for the heatmap 

5909 

5910 variant_labels_wmax = 0 

5911 variant_labels_wmax_text = "" 

5912 variants = OrderedDict() 

5913 

5914 # Option 1. Variants from GWAS rows 

5915 

5916 # The score can range from -1 to +1 

5917 # We will make a gradient palette based on the score for the fill 

5918 score_min = None 

5919 score_max = None 

5920 if gwas_path: 

5921 logging.info(f"Parsing GWAS variants: {gwas_path}") 

5922 with open(gwas_path) as infile: 

5923 header = [line.strip() for line in infile.readline().split("\t")] 

5924 for line in infile: 

5925 row = [l.strip() for l in line.split("\t")] 

5926 data = {k:v for k,v in zip(header, row)} 

5927 variant = f"{data['variant']}" 

5928 if "score" in data: 

5929 score = float(data["score"]) 

5930 if min_score != None and score < min_score: 

5931 logging.info(f"Excluding {variant} due to score {score} < {min_score}") 

5932 continue 

5933 if score < 0: 

5934 logging.warning(f"Converting {variant} score of {score} to 0.00") 

5935 score = 0.0 

5936 data["score"] = score 

5937 score_min = score if score_min == None else min(score_min, score) 

5938 score_max = score if score_max == None else max(score_max, score) 

5939 variants[variant] = {"data": data } 

5940 

5941 # Option 2. Variants from Rtab rows 

5942 elif rtab_path: 

5943 logging.info(f"Parsing Rtab variants: {rtab_path}") 

5944 with open(rtab_path) as infile: 

5945 header = infile.readline().strip().split("\t") 

5946 lines = infile.readlines() 

5947 for line in tqdm(lines): 

5948 row = [l.strip() for l in line.split("\t")] 

5949 variant = f"{row[0]}" 

5950 data = {col:val for col,val in zip(header,row)} 

5951 variants[variant] = {"data": data } 

5952 

5953 # Render variant label text to svg paths 

5954 variant_types = set() 

5955 if len(variants) > 0: 

5956 logging.info(f"Calculating variant label dimensions.") 

5957 for variant in tqdm(variants): 

5958 try: 

5959 variant_type = variant.split("|")[1].split(":")[0] 

5960 except IndexError: 

5961 variant_type = "unknown" 

5962 # Space out content around the pipe delim 

5963 text = variant.replace("|", " | ").replace("presence_absence", "present") 

5964 label = text_to_path(text, size=font_size, family=font_family) 

5965 w = label["w"] 

5966 if w > variant_labels_wmax: 

5967 variant_labels_wmax = w 

5968 variant_labels_wmax_text = text 

5969 variants[variant]["text"] = text 

5970 variants[variant]["type"] = variant_type 

5971 variants[variant]["label"] = label 

5972 variant_types.add(variant_type) 

5973 variants[variant]["fill"] = palette[variant_type] 

5974 

5975 logging.info(f"Grouping variants by type.") 

5976 variants_order = OrderedDict() 

5977 for variant_type in palette: 

5978 for variant,variant_data in variants.items(): 

5979 if variant_type != variant_data["type"]: continue 

5980 variants_order[variant] = variant_data 

5981 

5982 variants = variants_order 

5983 

5984 if variant_labels_wmax_text: 

5985 logging.info(f"The widest variant label is {math.ceil(variant_labels_wmax)}px: {variant_labels_wmax_text}") 

5986 # Order variant types by palette 

5987 variant_types = [v for v in palette if v in variant_types] 

5988 

5989 # ------------------------------------------------------------------------- 

5990 # Initialize plot data 

5991 

5992 plot_data = OrderedDict() 

5993 

5994 # If tree provided, orient plot around it 

5995 if tree_path: 

5996 for node in tree.preorder_node_iter(): 

5997 plot_data[node.label] = { 

5998 "x": None, 

5999 "y": None, 

6000 "parent": node.parent_node.label if node.parent_node else None, 

6001 "children": [c.label for c in node.child_nodes()], 

6002 "label": node_labels[node.label] if node.label in node_labels else OrderedDict(), 

6003 "heatmap": OrderedDict() 

6004 } 

6005 

6006 # Check if any variant samples are missing from the tree tips 

6007 for node in all_samples: 

6008 if node not in plot_data: 

6009 plot_data[node] = { 

6010 "x": None, 

6011 "y": None, 

6012 "parent": None, 

6013 "children": [], 

6014 "label": node_labels[node] if node in node_labels else OrderedDict(), 

6015 "heatmap": OrderedDict() 

6016 } 

6017 

6018 # ------------------------------------------------------------------------- 

6019 # Node coordinates 

6020 

6021 logging.info(f"Calculating node coordinates.") 

6022 

6023 # Identify the branch that sticks out the most, this will 

6024 # determine tip label placement 

6025 node_xmax = 0 

6026 node_i = 0 

6027 node_xmax_text = node_labels_wmax_text 

6028 

6029 # X Coordinates: Distance to the root 

6030 if tree_path: 

6031 for node,dist in zip( 

6032 tree.preorder_node_iter(), 

6033 tree.calc_node_root_distances(return_leaf_distances_only=False) 

6034 ): 

6035 plot_data[node.label]["x"] = dist 

6036 if dist > node_xmax: 

6037 node_xmax = dist 

6038 node_xmax_text = node.label 

6039 

6040 logging.info(f"The most distant tree node is: {node_xmax_text}") 

6041 

6042 # Y Coordinates: Start with tips 

6043 for node in list(tree.leaf_nodes()): 

6044 plot_data[node.label]["y"] = node_i 

6045 node_i += 1 

6046 

6047 # Y Coordinates: Place internal nodes at the mean y of their immediate children 

6048 for node in tree.postorder_internal_node_iter(): 

6049 children = [c.label for c in node.child_nodes()] 

6050 children_y = [plot_data[c]["y"] for c in children] 

6051 y = sum(children_y) / len(children) 

6052 plot_data[node.label]["y"] = y 

6053 

6054 # Set coords for non-tree samples 

6055 for node in plot_data: 

6056 if plot_data[node]["x"] == None: 

6057 plot_data[node]["x"] = 0 

6058 plot_data[node]["y"] = node_i 

6059 node_i += 1 

6060 

6061 # ------------------------------------------------------------------------- 

6062 # Coordinate scaling 

6063 

6064 # Rescale coordinates to pixels 

6065 # X Scaling: Based on the user's requested tree width 

6066 # Y Scaling: Based on the number of nodes in the tree 

6067 

6068 logging.info(f"Rescaling branch lengths to pixels.") 

6069 node_xmax = max([c["x"] for c in plot_data.values()]) 

6070 node_ymax = max([c["y"] for c in plot_data.values()]) 

6071 

6072 # Heatmap dimensions for scaling 

6073 heatmap_s = (node_labels_hmax * heatmap_scale) 

6074 heatmap_pad = heatmap_s / 4 

6075 heatmap_r = 2 

6076 heatmap_h = (heatmap_s * len(all_samples)) + (heatmap_pad * (len(all_samples) - 1)) 

6077 logging.info(f"Heatmap boxes will be {heatmap_s}px wide with {heatmap_pad}px of padding.") 
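# Hypothetical numbers for illustration: if the tallest sample label is 20px
# and heatmap_scale is 1.5, each box is 30px square with 7.5px of padding,
# so 10 samples give a heatmap height of 10 * 30 + 9 * 7.5 = 367.5px.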

6078 

6079 # Rescale the x coordinates based on pre-defined maximum width 

6080 if node_xmax > 0: 

6081 x_scale = math.floor(tree_width / node_xmax) 

6082 else: 

6083 x_scale = 1 

6084 # Rescale y coordinates based on the heatmap and text label dimensions 

6085 if len(all_samples) > 1: 

6086 tree_h = math.ceil((heatmap_s * len(all_samples)) + (heatmap_pad * (len(all_samples) - 1)) - heatmap_s) 

6087 y_scale = math.floor(tree_h / node_ymax) 

6088 else: 

6089 tree_h = math.ceil((heatmap_s * len(all_samples))) 

6090 y_scale = math.floor(tree_h) 

6091 

6092 # Rescale the node coordinates, and position them 

6093 logging.info(f"Positioning nodes.") 

6094 for node, data in plot_data.items(): 

6095 # Locate the node's x coordinate, based on the distance to root 

6096 plot_data[node]["x"] = math.floor(margin + root_branch + (data["x"] * x_scale)) 

6097 # Locate the node's y coordinate, if heatmap is available 

6098 if variant_labels_wmax > 0: 

6099 y = math.floor(margin + variant_labels_wmax + heatmap_s + (data["y"] * y_scale)) 

6100 else: 

6101 y = math.floor(margin + data["y"] * y_scale) 

6102 plot_data[node]["y"] = y 

6103 

6104 # Get the new X coord where tips start 

6105 node_xmax = plot_data[node_xmax_text]["x"] 

6106 node_labels_x = node_xmax + tip_pad 

6107 

6108 # Rescale the node label coordinates 

6109 logging.info(f"Positioning tip labels.") 

6110 for node in all_samples: 

6111 data = plot_data[node] 

6112 # Adjust the label y position to center on the first glyph 

6113 first_glyph = data["label"]["glyphs"][0] 

6114 first_glyph_h = first_glyph["h"] 

6115 label_y = data["y"] + (first_glyph_h / 2) 

6116 plot_data[node]["label"]["y"] = label_y 

6117 label_w = plot_data[node]["label"]["w"] 

6118 for i,glyph in enumerate(data["label"]["glyphs"]): 

6119 # Left alignment 

6120 #plot_data[node]["label"]["glyphs"][i]["x"] = node_labels_x + glyph["x"] 

6121 # Right alignment 

6122 plot_data[node]["label"]["glyphs"][i]["x"] = node_labels_x + (node_labels_wmax - label_w) + glyph["x"] 

6123 plot_data[node]["label"]["glyphs"][i]["y"] = label_y 

6124 

6125 # Heatmap labels 

6126 logging.info(f"Positioning variant labels.") 

6127 heatmap_x = node_labels_x + node_labels_wmax + tip_pad 

6128 heatmap_w = 0 

6129 

6130 # Add spaces between the different variant types 

6131 prev_var_type = None 

6132 prev_x = None 

6133 

6134 for variant in variants: 

6135 # Adjust the position to center on the first glyph 

6136 data = variants[variant]["data"] 

6137 variant_type = variants[variant]["type"] 

6138 glyphs = variants[variant]["label"]["glyphs"] 

6139 first_glyph = glyphs[0] 

6140 first_glyph_h = first_glyph["h"] 

6141 first_glyph_offset = first_glyph_h + ((heatmap_s - first_glyph_h)/ 2) 

6142 

6143 # Set the absolute position of the label 

6144 if prev_x == None: 

6145 x = heatmap_x + first_glyph_offset 

6146 elif prev_var_type != None and variant_type != prev_var_type: 

6147 x += heatmap_s + (heatmap_pad * 3) 

6148 else: 

6149 x = prev_x + (heatmap_s + heatmap_pad) 

6150 

6151 # Set the absolute position of the box 

6152 variants[variant]["box"] = {"x": x - first_glyph_offset} 

6153 prev_x = x 

6154 prev_var_type = variant_type 

6155 # Add an extra 2 pixels to make it reach just beyond the final box 

6156 heatmap_w = (2 + heatmap_s + x - first_glyph_offset) - heatmap_x 

6157 

6158 y = margin + variant_labels_wmax 

6159 variants[variant]["label"]["x"] = x 

6160 variants[variant]["label"]["y"] = y 

6161 for i_g,g in enumerate(glyphs): 

6162 variants[variant]["label"]["glyphs"][i_g]["x"] = x + g["x"] 

6163 variants[variant]["label"]["glyphs"][i_g]["y"] = y 

6164 

6165 # ------------------------------------------------------------------------- 

6166 # Heatmap Table Data 

6167 

6168 score_min_alpha, score_max_alpha = 0.10, 1.0 

6169 

6170 logging.info(f"Positioning heatmap boxes.") 

6171 for i,variant in enumerate(variants): 

6172 variant_type = variants[variant]["type"] 

6173 data = variants[variant]["data"] 

6174 label = variants[variant]["label"] 

6175 x = variants[variant]["box"]["x"] 

6176 for node in all_samples: 

6177 # Ex. a sample in tree but not heatmap 

6178 if node not in data: continue 

6179 v = data[node] 

6180 v = int(v) if v.isdigit() else str(v) 

6181 fill_opacity = 1.0 

6182 # Missing values (".") are left undrawn (fill and stroke set to "none" below) 

6183 score = data["score"] if "score" in data else None 

6184 if v == 1: 

6185 box_fill = variants[variant]["fill"] 

6186 # Use score range if possible 

6187 if score_min != None and "score" in data: 

6188 score = float(data["score"]) 

6189 # Option 1: Fixed range from 0 to 1 

6190 fill_opacity = linear_scale(score, 0, 1.0, score_min_alpha, score_max_alpha) 

6191 # Option 2. Relative range from score_min to score_max 

6192 # fill_opacity = linear_scale(score, score_min, score_max, score_min_alpha, score_max_alpha) 

6193 box_stroke = "black" 

6194 elif v == 0: 

6195 box_fill = "white" 

6196 box_stroke = "black" 

6197 else: 

6198 box_fill = "none" 

6199 box_stroke = "none" 

6200 y = plot_data[node]["y"] - (heatmap_s / 2) 

6201 d = { 

6202 "v": v, "x": x, "y": y, "w": heatmap_s, "h": heatmap_s, 

6203 "r": heatmap_r, "fill": box_fill, "fill_opacity": fill_opacity, "stroke": box_stroke, 

6204 "hovertext": data, 

6205 } 

6206 plot_data[node]["heatmap"][variant] = d 

6207 
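# Worked example with a hypothetical GWAS score of 0.5: assuming linear_scale
# is a standard linear interpolation, the fixed 0-to-1 option above gives
# fill_opacity = 0.10 + 0.5 * (1.0 - 0.10) = 0.55.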

6208 # ----------------------------------------------------------------------------- 

6209 # Legend 

6210 

6211 legend_x = (heatmap_x + heatmap_w + heatmap_s) 

6212 legend_y = margin + variant_labels_wmax + tip_pad 

6213 legend_w = 0 

6214 legend_h = 0 

6215 legend = OrderedDict() 

6216 legend_labels_hmax = 0 

6217 legend_labels_wmax = 0 

6218 legend_title_wmax = 0 

6219 legend_title_hmax = 0 

6220 

6221 if score_min != None and score_max != None: 

6222 logging.info(f"Drawing legend at: {legend_x}, {legend_y}") 

6223 legend_h = (heatmap_s * 3) + (heatmap_pad * 2) 

6224 # TBD: Think about negative score handling? 

6225 # Option 1. Fixed Values 

6226 legend_values = ["0.00", "0.25", "0.50", "0.75", "1.00"] 

6227 # # Option 2. Relative to score max 

6228 # interval = score_max / 4 

6229 # legend_values = [round(i * interval,2) for i in range(0,4)] + [round(score_max, 2)] 

6230 # legend_values = ["{:.2f}".format(v) for v in legend_values] 

6231 

6232 # Tick labels: convert to path 

6233 tick_labels = [text_to_path(text, size=font_size * 0.50, family=font_family) for text in legend_values] 

6234 for label in tick_labels: 

6235 legend_labels_hmax = max(legend_labels_hmax, label["h"]) 

6236 legend_labels_wmax = max(legend_labels_wmax, label["w"]) 

6237 title_label = text_to_path("Score", size=font_size * 0.75, family=font_family) 

6238 legend_title_hmax = max(legend_title_hmax, title_label["h"]) 

6239 legend_title_wmax = max(legend_title_wmax, title_label["w"]) 

6240 

6241 

6242 # Width of legend + tick len + pad + label wmax + pad 

6243 entry_w = heatmap_s + heatmap_pad + (heatmap_pad / 2) + legend_labels_wmax + heatmap_s 

6244 if (legend_title_wmax + heatmap_s) > entry_w: 

6245 entry_w = legend_title_wmax + heatmap_s 

6246 

6247 for i_v, variant_type in enumerate(variant_types): 

6248 box = { 

6249 "x": legend_x + (i_v * entry_w), 

6250 "y": legend_y, 

6251 "w": heatmap_s, 

6252 "h": (heatmap_s * 3) + (heatmap_pad * 2), 

6253 "fill": f"url(#{variant_type}_gradient)", 

6254 } 

6255 title = copy.deepcopy(title_label) 

6256 title["x"] = box["x"] 

6257 title["y"] = box["y"] - tip_pad 

6258 for i_g,g in enumerate(title["glyphs"]): 

6259 title["glyphs"][i_g]["x"] = title["x"] + g["x"] 

6260 title["glyphs"][i_g]["y"] = title["y"] 

6261 ticks = [] 

6262 for i_t,text in enumerate(reversed(legend_values)): 

6263 x1 = box["x"] + box["w"] 

6264 x2 = x1 + heatmap_pad 

6265 y = legend_y + (i_t * (legend_h / (len(legend_values) - 1))) 

6266 label = copy.deepcopy(list(reversed(tick_labels))[i_t]) 

6267 label["x"] = x2 + (heatmap_pad / 2) 

6268 label["y"] = y 

6269 

6270 # Adjust the label y position to center on first numeric char 

6271 center_glyph_h = label["glyphs"][0]["h"] 

6272 label["y"] = y + (center_glyph_h / 2) 

6273 for i_g,g in enumerate(label["glyphs"]): 

6274 label["glyphs"][i_g]["x"] = label["x"] + g["x"] 

6275 label["glyphs"][i_g]["y"] = label["y"] 

6276 

6277 tick = {"text": text, "x1": x1, "x2":x2, "y": y, "label": copy.deepcopy(label) } 

6278 ticks.append(tick) 

6279 

6280 legend[variant_type] = {"title": title, "box": box, "ticks": ticks} 

6281 legend_w += entry_w 

6282 

6283 # ----------------------------------------------------------------------------- 

6284 # Draw Tree 

6285 

6286 logging.info(f"Calculating final image dimensions.") 

6287 # Final image dimensions 

6288 if len(legend) > 0: 

6289 logging.info("Final dimensions with legend.") 

6290 width = math.ceil(legend_x + legend_w + margin) 

6291 height = math.ceil(margin + variant_labels_wmax + tip_pad + tree_h + heatmap_s + margin) 

6292 elif heatmap_w > 0 and heatmap_h > 0: 

6293 logging.info("Final dimensions with heatmap.") 

6294 width = math.ceil(heatmap_x + heatmap_w + margin) 

6295 height = math.ceil(margin + variant_labels_wmax + tip_pad + tree_h + heatmap_s + margin) 

6296 else: 

6297 logging.info("Final dimensions without heatmap.") 

6298 width = math.ceil(node_labels_x + node_labels_wmax + margin) 

6299 height = math.ceil(tree_h + (2*margin)) 

6300 

6301 svg_path = os.path.join(output_dir, f"{prefix}plot.svg") 

6302 logging.info(f"Rendering output svg ({width}x{height}): {svg_path}") 

6303 with open(svg_path, 'w') as outfile: 

6304 header = textwrap.dedent( 

6305 f"""\ 

6306 <svg  

6307 version="1.1" 

6308 xmlns="http://www.w3.org/2000/svg" 

6309 xmlns:xlink="http://www.w3.org/1999/xlink" 

6310 preserveAspectRatio="xMidYMid meet" 

6311 width="{width}" 

6312 height="{height}" 

6313 viewBox="0 0 {width} {height}"> 

6314 """) 

6315 

6316 outfile.write(header.strip() + "\n") 

6317 

6318 branches = [] 

6319 tip_circles = [] 

6320 tip_dashes = [] 

6321 tip_labels = [] 

6322 variant_labels = OrderedDict() 

6323 heatmap_boxes = OrderedDict() 

6324 focal_boxes = OrderedDict() 

6325 legend_entries = OrderedDict() 

6326 

6327 # Variant Heatmap Labels 

6328 for variant in variants: 

6329 label = variants[variant]["label"] 

6330 rx, ry = label["x"], label["y"] 

6331 if variant not in variant_labels: 

6332 variant_labels[variant] = [] 

6333 heatmap_boxes[variant] = [] 

6334 for g in label["glyphs"]: 

6335 line = f"<path transform='rotate(-90, {rx}, {ry}) translate({g['x']},{g['y']})' d='{g['d']}' />" 

6336 variant_labels[variant].append(line) 

6337 

6338 for node,data in plot_data.items(): 

6339 logging.debug(f"node: {node}, data: {data}") 

6340 

6341 # Root branch, add custom little starter branch 

6342 if not data["parent"]: 

6343 cx, cy = data['x'], data['y'] 

6344 px, py = data['x'] - root_branch, cy 

6345 # Non-root 

6346 else: 

6347 p_data = plot_data[data["parent"]] 

6348 cx, cy, px, py = data['x'], data['y'], p_data['x'], p_data['y'] 

6349 

6350 # Draw tree lines 

6351 if node in tree_nodes: 

6352 line = f"<path d='M {cx} {cy} H {px} V {py}' style='stroke:black;stroke-width:2;fill:none'/>" 

6353 branches.append(line) 

6354 

6355 # Focal box 

6356 if node in focal: 

6357 # Draw the box to the midway point between it and the next label 

6358 if len(plot_data) == 1: 

6359 rh = node_labels_hmax 

6360 else: 

6361 # Figure out adjacent tip 

6362 tip_i = all_samples.index(node) 

6363 adjacent_tip = all_samples[tip_i+1] if tip_i == 0 else all_samples[tip_i-1] 

6364 rh = abs(plot_data[adjacent_tip]["y"] - data['y']) 

6365 ry = data['y'] - (rh /2) 

6366 rx = node_labels_x 

6367 rw = node_labels_wmax + tip_pad + heatmap_w 

6368 

6369 rect = f"<rect width='{rw}' height='{rh}' x='{rx}' y='{ry}' rx='1' ry='1' style='fill:grey;fill-opacity:.40'/>" 

6370 focal_boxes[node] = [rect] 

6371 

6372 # Dashed line: tree -> tip 

6373 if node in tree_tips: 

6374 lx = node_labels_x + (node_labels_wmax - data["label"]["w"]) 

6375 line = f"<path d='M {cx + 4} {cy} H {lx - 4}' style='stroke:grey;stroke-width:1;fill:none' stroke-dasharray='4 4'/>" 

6376 tip_dashes.append(line) 

6377 

6378 # Tip circles 

6379 if node in focal: 

6380 circle = f"<circle cx='{cx}' cy='{cy}' r='4' style='fill:black;stroke:black;stroke-width:1' />" 

6381 tip_circles.append(circle) 

6382 

6383 # Sample Labels 

6384 if node in all_samples: 

6385 for g in data["label"]["glyphs"]: 

6386 line = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />" 

6387 tip_labels.append(line) 

6388 # Heatmap 

6389 if node in all_samples: 

6390 for variant,d in data["heatmap"].items(): 

6391 

6392 hover_text = "\n".join([f"{k}: {v}" for k,v in d["hovertext"].items() if k not in all_samples and not k.endswith("_alt")]) 

6393 rect = f"<rect width='{d['w']}' height='{d['h']}' x='{d['x']}' y='{d['y']}' rx='{d['r']}' ry='{d['r']}' style='fill:{d['fill']};fill-opacity:{d['fill_opacity']};stroke-width:1;stroke:{d['stroke']}'>" 

6394 title = f"<title>{hover_text}</title>" 

6395 heatmap_boxes[variant] += [rect, title, "</rect>"] 

6396 

6397 # Legend 

6398 for variant_type in legend: 

6399 legend_entries[variant_type] = [] 

6400 # Gradient 

6401 fill = palette[variant_type] 

6402 gradient = f""" 

6403 <linearGradient id="{variant_type}_gradient" x1="0" x2="0" y1="0" y2="1" gradientTransform="rotate(180 0.5 0.5)"> 

6404 <stop offset="0%" stop-color="{fill}" stop-opacity="{score_min_alpha}" /> 

6405 <stop offset="100%" stop-color="{fill}" stop-opacity="{score_max_alpha}" /> 

6406 </linearGradient>""" 

6407 legend_entries[variant_type].append(gradient) 

6408 # Legend Title 

6409 title = legend[variant_type]["title"] 

6410 for g in title["glyphs"]: 

6411 label = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />" 

6412 legend_entries[variant_type].append(label) 

6413 # Legend box 

6414 d = legend[variant_type]["box"] 

6415 rect = f"<rect width='{d['w']}' height='{d['h']}' x='{d['x']}' y='{d['y']}' style='fill:{d['fill']};stroke-width:1;stroke:black'/>" 

6416 legend_entries[variant_type].append(rect) 

6417 # Legend Ticks 

6418 for t in legend[variant_type]["ticks"]: 

6419 # Tick line 

6420 line = f"<path d='M {t['x1']} {t['y']} H {t['x2']} V {t['y']}' style='stroke:black;stroke-width:1;fill:none'/>" 

6421 legend_entries[variant_type].append(line) 

6422 # Tick Label 

6423 for g in t["label"]["glyphs"]: 

6424 label = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />" 

6425 legend_entries[variant_type].append(label) 

6426 

6427 # Draw elements in groups 

6428 indent = " " 

6429 outfile.write(f"{indent * 1}<g id='Plot'>" + "\n") 

6430 

6431 # White canvas background 

6432 background = f"<rect width='{width}' height='{height}' x='0' y='0' style='fill:white;stroke-width:1;stroke:white'/>" 

6433 outfile.write(f"{indent * 2}<g id='Background'>\n{indent * 3}{background}\n{indent * 2}</g>" + "\n") 

6434 

6435 # Focal boxes start 

6436 outfile.write(f"{indent * 2}<g id='Focal'>" + "\n") 

6437 # Focal Boxes 

6438 outfile.write(f"{indent * 3}<g id='Focal Boxes'>\n") 

6439 for sample,boxes in focal_boxes.items(): 

6440 outfile.write(f"{indent * 4}<g id='{sample}'>\n") 

6441 outfile.write("\n".join([f"{indent * 5}{b}" for b in boxes]) + "\n") 

6442 outfile.write(f"{indent * 4}</g>\n") 

6443 outfile.write(f"{indent * 3}</g>\n") 

6444 # Focal boxes end 

6445 outfile.write(f"{indent * 2}</g>\n") 

6446 

6447 # Tree Start 

6448 outfile.write(f"{indent * 2}<g id='Tree'>" + "\n") 

6449 # Branches 

6450 outfile.write(f"{indent * 3}<g id='Branches'>\n") 

6451 outfile.write("\n".join([f"{indent * 4}{b}" for b in branches]) + "\n") 

6452 outfile.write(f"{indent * 3}</g>\n") 

6453 # Tip Circles 

6454 outfile.write(f"{indent * 3}<g id='Tip Circles'>\n") 

6455 outfile.write("\n".join([f"{indent * 4}{c}" for c in tip_circles]) + "\n") 

6456 outfile.write(f"{indent * 3}</g>\n") 

6457 # Tip Dashes 

6458 outfile.write(f"{indent * 3}<g id='Tip Dashes'>\n") 

6459 outfile.write("\n".join([f"{indent * 4}{d}" for d in tip_dashes]) + "\n") 

6460 outfile.write(f"{indent * 3}</g>\n") 

6461 # Tip Labels 

6462 outfile.write(f"{indent * 3}<g id='Tip Labels'>\n") 

6463 outfile.write("\n".join([f"{indent * 4}{t}" for t in tip_labels]) + "\n") 

6464 outfile.write(f"{indent * 3}</g>\n") 

6465 # Tree End 

6466 outfile.write(f"{indent * 2}</g>\n") 

6467 

6468 # Heatmap Start 

6469 outfile.write(f"{indent * 2}<g id='Heatmap'>" + "\n") 

6470 # Variant Labels 

6471 outfile.write(f"{indent * 3}<g id='Variant Labels'>\n") 

6472 for variant,paths in variant_labels.items(): 

6473 outfile.write(f"{indent * 4}<g id='{variant}'>\n") 

6474 outfile.write("\n".join([f"{indent * 5}{p}" for p in paths]) + "\n") 

6475 outfile.write(f"{indent * 4}</g>\n") 

6476 outfile.write(f"{indent * 3}</g>\n") 

6477 # Heatmap Boxes 

6478 outfile.write(f"{indent * 3}<g id='Heatmap Boxes'>\n") 

6479 for variant,boxes in heatmap_boxes.items(): 

6480 outfile.write(f"{indent * 4}<g id='{variant}'>\n") 

6481 outfile.write("\n".join([f"{indent * 5}{b}" for b in boxes]) + "\n") 

6482 outfile.write(f"{indent * 4}</g>\n") 

6483 outfile.write(f"{indent * 3}</g>\n") 

6484 # Heatmap End 

6485 outfile.write(f"{indent * 2}</g>\n") 

6486 

6487 # Legend 

6488 if score_min != None and score_max != None: 

6489 outfile.write(f"{indent * 2}<g id='Legend'>\n") 

6490 for variant_type in legend_entries: 

6491 outfile.write(f"{indent * 3}<g id='{variant_type}'>\n") 

6492 for element in legend_entries[variant_type]: 

6493 outfile.write(f"{indent * 4}{element}\n") 

6494 outfile.write(f"{indent * 3}</g>\n") 

6495 outfile.write(f"{indent * 2}</g>\n") 

6496 

6497 # outfile.write(f"{indent * 3}<g id='{variant_type}'>") 

6498 # fill = palette[variant_type] 

6499 # gradient = f""" 

6500 # <linearGradient id="{variant_type}_gradient" x1="0" x2="0" y1="0" y2="1" gradientTransform="rotate(180 0.5 0.5)"> 

6501 # <stop offset="0%" stop-color="{fill}" stop-opacity="{score_min_alpha}" /> 

6502 # <stop offset="100%" stop-color="{fill}" stop-opacity="{score_max_alpha}" /> 

6503 # </linearGradient>""" 

6504 # outfile.write(f"{indent * 4}{gradient}\n") 

6505 # outfile.write(f"{indent * 3}</g>\n") 

6506 # # Box 

6507 # rect = f"<rect width='{legend_w}' height='{legend_h}' x='{legend_x}' y='{legend_y}' style='stroke:black;stroke-width:1;fill:url(#LegendGradient)'/>" 

6508 # outfile.write(f"{indent * 3}{rect}\n") 

6509 # # Ticks 

6510 # for variant_type in legend_ticks: 

6511 # for t in legend_ticks[variant_type]: 

6512 # line = f"<path d='M {t['x1']} {t['y']} H {t['x2']} V {t['y']}' style='stroke:black;stroke-width:1;fill:none'/>" 

6513 # outfile.write(f"{indent * 3}{line}\n") 

6514 # # Label 

6515 # for g in t["label"]["glyphs"]: 

6516 # glyph = f"<path transform='translate({g['x']},{g['y']})' d='{g['d']}' />" 

6517 # outfile.write(f"{indent * 3}{glyph}\n") 

6518 

6519 # Close out the plot group 

6520 outfile.write(f"{indent * 1}</g>\n") 

6521 outfile.write("</svg>") 

6522 

6523 png_path = os.path.join(output_dir, f"{prefix}plot.png") 

6524 logging.info(f"Rendering output png ({width}x{height}): {png_path}") 

6525 cairosvg.svg2png(url=svg_path, write_to=png_path, output_width=width, output_height=height, scale=png_scale) 

6526 

6527 return svg_path 

6528 

6529def cli(args:str=None): 

6530 

6531 # Parse input args from system or function input 

6532 sys_argv_original = sys.argv 

6533 if args != None: 

6534 sys.argv = f"pangwas {args}".split(" ") 

6535 command = " ".join(sys.argv) 

6536 if len(sys.argv) == 1: 

6537 sys.argv = ["pangwas", "--help"] 

6538 

6539 # Handle sys exits by argparse more gracefully 

6540 try: 

6541 options = get_options() 

6542 except SystemExit as exit_status: 

6543 if f"{exit_status}" == "0": 

6544 return 0 

6545 elif f"{exit_status}" == "2": 

6546 return exit_status 

6547 else: 

6548 msg = f"pangwas cli exited with status: {exit_status}" 

6549 logging.error(msg) 

6550 return exit_status 

6551 

6552 # Restore original argv 

6553 sys.argv = sys_argv_original 

6554 

6555 if options.version: 

6556 import importlib.metadata 

6557 version = importlib.metadata.version("pangwas") 

6558 print(f"pangwas v{version}") 

6559 return 0 

6560 

6561 logging.info("Begin") 

6562 logging.info(f"Command: {command}") 

6563 

6564 # Organize options as kwargs for functions 

6565 kwargs = {k:v for k,v in vars(options).items() if k not in ["version", "subcommand"]} 

6566 try: 

6567 fn = globals()[options.subcommand] 

6568 fn(**kwargs) 

6569 except KeyError: 

6570 logging.error(f"A pangwas subcommand is required (ex. pangwas extract).") 

6571 return 1 

6572 

6573 logging.info("Done") 

6574 

6575if __name__ == "__main__": 

6576 cli()
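# Example invocation of the CLI entry point above; the flag names follow the
# option names referenced in the subcommand docstrings and error messages and
# are otherwise an assumption about get_options():
#   pangwas heatmap --tree tree.rooted.nwk --gwas resistant.locus_effects.significant.tsv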