1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package org.utgenome.util.repeat;
24
25 import java.io.BufferedOutputStream;
26 import java.io.BufferedReader;
27 import java.io.BufferedWriter;
28 import java.io.File;
29 import java.io.FileOutputStream;
30 import java.io.FileReader;
31 import java.io.FileWriter;
32 import java.io.IOException;
33 import java.io.Writer;
34 import java.util.ArrayList;
35 import java.util.Collections;
36 import java.util.Comparator;
37 import java.util.Date;
38 import java.util.List;
39 import java.util.Set;
40 import java.util.TreeMap;
41
42 import org.utgenome.UTGBErrorCode;
43 import org.utgenome.UTGBException;
44 import org.utgenome.format.fasta.CompactACGT;
45 import org.utgenome.format.fasta.FASTA;
46 import org.utgenome.gwt.utgb.client.bio.Interval;
47 import org.utgenome.gwt.utgb.client.canvas.IntervalTree;
48 import org.xerial.lens.Lens;
49 import org.xerial.silk.SilkWriter;
50 import org.xerial.util.ObjectHandlerBase;
51 import org.xerial.util.graph.AdjacencyList;
52 import org.xerial.util.graph.Edge;
53 import org.xerial.util.log.Logger;
54 import org.xerial.util.opt.Argument;
55 import org.xerial.util.opt.Option;
56 import org.xerial.util.opt.OptionParser;
57 import org.xerial.util.opt.OptionParserException;
58 import org.xerial.util.text.TabAsTreeParser;
59
60
61
62
63
64
65
66 public class RepeatChainFinder {
67
68 private static Logger _logger = Logger.getLogger(RepeatChainFinder.class);
69
70 @Argument(index = 0)
71 private File intervalFile;
72 @Argument(index = 1)
73 private File fastaFile;
74
75 @Option(symbol = "t", longName = "threshold", description = "threshold for connecting fragments")
76 private int threshold = 50;
77
78 @Option(symbol = "s", description = "sequence name to read from FASTA")
79 private String chr;
80
81 @Option(symbol = "o", description = "output folder")
82 private String outFolder;
83
84
85
86
87
88
89
90 public static class Interval2D extends Interval {
91
92 private static final long serialVersionUID = 1L;
93
94 public int y1;
95 public int y2;
96
97 public Interval2D() {
98 }
99
100 public Interval2D(int x1, int y1, int x2, int y2) {
101 super(x1, x2);
102 this.y1 = y1;
103 this.y2 = y2;
104 }
105
106 public Interval2D(Interval2D first, Interval2D last) {
107 super(first.getStart(), last.getEnd());
108 this.y1 = first.y1;
109 this.y2 = last.y2;
110 }
111
112 public int compareTo(Interval2D other) {
113
114 int diff = getStart() - other.getStart();
115 if (diff != 0)
116 return diff;
117
118 diff = y1 - other.y1;
119 if (diff != 0)
120 return diff;
121
122 diff = getEnd() - other.getEnd();
123 if (diff != 0)
124 return diff;
125
126 diff = y2 - other.y2;
127 if (diff != 0)
128 return diff;
129
130 return 0;
131 }
132
133 public int maxLength() {
134 return Math.max(super.length(), Math.abs(y2 - y1));
135 }
136
137 @Override
138 public String toString() {
139 return String.format("max len:%d, (%d, %d)-(%d, %d)", maxLength(), getStart(), y1, getEnd(), y2);
140 }
141
142 public Interval startPoint() {
143 return new Interval(getStart(), y1);
144 }
145
146 public Interval endPoint() {
147 return new Interval(getEnd(), y2);
148 }
149
150 public boolean isInLowerRightRegion() {
151 return y1 < getStart();
152 }
153
154 public boolean isForward() {
155 return y1 <= y2;
156 }
157
158 public int forwardDistance(Interval2D other) {
159 int xDiff = Math.abs(other.getStart() - this.getEnd());
160 int yDiff = Math.abs(other.y1 - this.y2);
161 if (this.isForward()) {
162 if (other.isForward())
163 return Math.max(xDiff, yDiff);
164 else
165 return -1;
166 }
167 else {
168 if (other.isForward())
169 return -1;
170 else
171 return Math.max(xDiff, yDiff);
172 }
173 }
174
175 public void toDotFormat(Writer out) throws IOException {
176 out.append(String.format("%d\t%d\t%d\t%d\n", getStart(), y1, getEnd(), y2));
177 }
178
179 @Override
180 public boolean equals(Object o) {
181 Interval2D other = Interval2D.class.cast(o);
182 return this.getStart() == other.getStart() && this.getEnd() == other.getEnd() && this.y1 == other.y1 && this.y2 == other.y2;
183 }
184
185 @Override
186 public int hashCode() {
187 int hash = 3;
188 hash += 137 * this.getStart();
189 hash += 137 * this.getEnd();
190 hash += 137 * this.y1;
191 hash += 137 * this.y2;
192 return hash;
193 }
194
195 }
196
197
198
199
200
201
202 public static class FlippedInterval2D extends Interval {
203 private static final long serialVersionUID = 1L;
204 final Interval2D orig;
205
206 public FlippedInterval2D(Interval2D orig) {
207 super(orig.y1, orig.y2);
208 this.orig = orig;
209 }
210
211 }
212
213 public static class IntervalCluster implements Comparable<IntervalCluster> {
214
215 public int id;
216 public final List<Interval2D> component;
217 public final int length;
218
219 public IntervalCluster(List<Interval2D> elements) {
220 this.component = elements;
221 Collections.sort(elements, new Comparator<Interval2D>() {
222 public int compare(Interval2D o1, Interval2D o2) {
223 return o2.maxLength() - o1.maxLength();
224 }
225 });
226
227 int maxLength = -1;
228 for (Interval2D each : elements) {
229 if (maxLength < each.maxLength())
230 maxLength = each.maxLength();
231 }
232 this.length = maxLength;
233 }
234
235 public void setId(int id) {
236 this.id = id;
237 }
238
239 public void validate() throws UTGBException {
240
241 if (component.size() <= 1)
242 return;
243
244 for (Interval2D p : component) {
245 boolean hasOverlap = false;
246 for (Interval2D q : component) {
247 if (p == q)
248 continue;
249 if (p.hasOverlap(q)) {
250 hasOverlap = true;
251 break;
252 }
253 else {
254 Interval py = new Interval(p.y1, p.y2);
255 Interval qy = new Interval(q.y1, q.y2);
256 if (py.hasOverlap(qy)) {
257 hasOverlap = true;
258 break;
259 }
260 }
261 }
262 if (!hasOverlap)
263 throw new UTGBException(UTGBErrorCode.ValidationFailure);
264 }
265 }
266
267 public int compareTo(IntervalCluster o) {
268 return this.length - o.length;
269 }
270
271 public int size() {
272 return component.size();
273 }
274
275 public void toDotFile(Writer out) throws IOException {
276 for (Interval2D each : component) {
277 each.toDotFormat(out);
278 }
279 }
280
281 @Override
282 public String toString() {
283 return String.format("length:%d, size:%d", length, size());
284 }
285
286 }
287
288 public static void main(String[] args) {
289 RepeatChainFinder finder = new RepeatChainFinder();
290 OptionParser opt = new OptionParser(finder);
291 try {
292 opt.parse(args);
293 finder.execute(args);
294 }
295 catch (OptionParserException e) {
296 _logger.error(e);
297 }
298 catch (Exception e) {
299 _logger.error(e);
300 e.printStackTrace(System.err);
301 }
302 }
303
304 final AdjacencyList<Interval2D, Integer> graph = new AdjacencyList<Interval2D, Integer>();
305 final ArrayList<Interval2D> rangeList = new ArrayList<Interval2D>();
306
307 public void execute(String[] args) throws Exception {
308
309 if (intervalFile == null)
310 throw new UTGBException(UTGBErrorCode.MISSING_FILES, "no input file is given");
311
312 if (outFolder == null) {
313 outFolder = String.format("target/cluster-T%d", threshold);
314 new File(outFolder).mkdirs();
315 }
316
317 {
318 int numChain = 0;
319 final List<Interval2D> intervals = new ArrayList<Interval2D>();
320
321 _logger.info("chain threshold: " + threshold);
322
323 _logger.info("loading intervals..");
324
325
326 BufferedReader in = new BufferedReader(new FileReader(intervalFile));
327 try {
328 TabAsTreeParser t = new TabAsTreeParser(in);
329 List<String> label = new ArrayList<String>();
330 label.add("start");
331 label.add("y1");
332 label.add("end");
333 label.add("y2");
334 t.setColunLabel(label);
335 t.setRowNodeName("entry");
336
337 Lens.find(Interval2D.class, "entry", new ObjectHandlerBase<Interval2D>() {
338 public void handle(Interval2D interval) throws Exception {
339 if (!interval.isInLowerRightRegion())
340 intervals.add(interval);
341 }
342
343 @Override
344 public void finish() throws Exception {
345 _logger.info(String.format("loaded %d intervals", intervals.size()));
346 }
347 }, t);
348 }
349 finally {
350 in.close();
351 }
352
353 _logger.info("sorting intervals...");
354
355 Collections.sort(intervals);
356
357 _logger.info("sweeping intervals...");
358 {
359 final IntervalTree<Interval2D> intervalTree = new IntervalTree<Interval2D>();
360 for (Interval2D current : intervals) {
361
362
363 intervalTree.removeBefore(current.getStart() - threshold);
364
365
366 for (Interval2D each : intervalTree) {
367 final int dist = each.forwardDistance(current);
368 if (dist > 0 && dist < threshold) {
369 graph.addEdge(each, current, dist);
370 }
371 }
372
373 intervalTree.add(current);
374 }
375 }
376
377 if (_logger.isTraceEnabled())
378 _logger.trace("graph:\n" + graph.toGraphViz());
379
380
381 _logger.info("chaining...");
382 for (Interval2D node : graph.getNodeLabelSet()) {
383 List<Interval2D> adjacentNodes = new ArrayList<Interval2D>();
384 for (Edge each : graph.getOutEdgeSet(node)) {
385 adjacentNodes.add(graph.getNodeLabel(each.getDestNodeID()));
386 }
387
388 if (_logger.isTraceEnabled())
389 _logger.trace(String.format("node %s -> %s", node, adjacentNodes));
390 }
391
392
393 _logger.info("enumerating connected paths...");
394 for (Interval2D each : graph.getNodeLabelSet()) {
395 if (!graph.getInEdgeSet(each).isEmpty())
396 continue;
397
398
399 findPathsToLeaf(each, each);
400 }
401 _logger.info("# of paths : " + rangeList.size());
402
403
404 {
405
406 TreeMap<Interval, Interval2D> longestRange = new TreeMap<Interval, Interval2D>(new Comparator<Interval>() {
407 public int compare(Interval o1, Interval o2) {
408 int diff = o1.getStart() - o2.getStart();
409 if (diff == 0)
410 return o1.getEnd() - o2.getEnd();
411 else
412 return diff;
413 }
414 });
415 {
416
417 for (Interval2D each : rangeList) {
418 Interval key = each.startPoint();
419 if (longestRange.containsKey(key)) {
420 Interval2D prev = longestRange.get(key);
421 if (prev.maxLength() < each.maxLength()) {
422 longestRange.remove(key);
423 longestRange.put(key, each);
424 }
425 }
426 else
427 longestRange.put(key, each);
428 }
429 }
430 rangeList.clear();
431 rangeList.addAll(longestRange.values());
432
433 longestRange.clear();
434 {
435
436 for (Interval2D each : rangeList) {
437 Interval key = each.endPoint();
438 if (longestRange.containsKey(key)) {
439 Interval2D prev = longestRange.get(key);
440 if (prev.maxLength() < each.maxLength()) {
441 longestRange.remove(key);
442 longestRange.put(key, each);
443 }
444 }
445 else
446 longestRange.put(key, each);
447 }
448 }
449
450 rangeList.clear();
451 rangeList.addAll(longestRange.values());
452
453 }
454 _logger.info("# of unique paths : " + rangeList.size());
455
456 BufferedWriter pathOut = new BufferedWriter(new FileWriter(new File(outFolder, "paths.dot")));
457 try {
458
459 for (Interval2D each : rangeList) {
460 each.toDotFormat(pathOut);
461 }
462 }
463 finally {
464 pathOut.flush();
465 pathOut.close();
466 }
467
468
469
470
471 DisjointSet<Interval2D> clusterSet = new DisjointSet<Interval2D>();
472 {
473 _logger.info("clustring paths in X-coordinate...");
474 IntervalTree<Interval2D> xOverlapChecker = new IntervalTree<Interval2D>();
475 for (Interval2D each : rangeList) {
476 clusterSet.add(each);
477 for (Interval2D overlapped : xOverlapChecker.overlapQuery(each)) {
478 if (each.contains(overlapped) || overlapped.contains(each))
479 clusterSet.union(overlapped, each);
480 }
481 xOverlapChecker.add(each);
482 }
483
484 _logger.info("# of disjoint sets: " + clusterSet.rootNodeSet().size());
485 }
486
487 {
488 for (IntervalCluster cluster : createClusters(clusterSet)) {
489 BufferedWriter xClusterOut = new BufferedWriter(new FileWriter(new File(outFolder, String.format("x_cluster%03d.dot", cluster.id))));
490 try {
491
492 cluster.toDotFile(xClusterOut);
493 }
494 finally {
495 xClusterOut.flush();
496 xClusterOut.close();
497 }
498 }
499 }
500
501 {
502 _logger.info("clustring paths in Y-coordinate...");
503 IntervalTree<FlippedInterval2D> yOverlapChecker = new IntervalTree<FlippedInterval2D>();
504 for (Interval2D each : rangeList) {
505 FlippedInterval2D flip = new FlippedInterval2D(each);
506 for (FlippedInterval2D overlapped : yOverlapChecker.overlapQuery(flip)) {
507 if (flip.contains(overlapped) || overlapped.contains(flip))
508 clusterSet.union(overlapped.orig, each);
509 }
510 yOverlapChecker.add(flip);
511 }
512
513 {
514 for (IntervalCluster cluster : createClusters(clusterSet)) {
515 BufferedWriter yClusterOut = new BufferedWriter(new FileWriter(new File(outFolder, String.format("y_cluster%03d.dot", cluster.id))));
516 try {
517
518 cluster.toDotFile(yClusterOut);
519 }
520 finally {
521 yClusterOut.flush();
522 yClusterOut.close();
523 }
524 }
525 }
526
527
528 new SegmentReport().reportCluster(clusterSet);
529 }
530
531 _logger.info("done");
532 }
533
534 }
535
536 public List<IntervalCluster> createClusters(DisjointSet<Interval2D> clusterSet) {
537 List<IntervalCluster> clusterList = new ArrayList<IntervalCluster>();
538 Set<Interval2D> clusterRoots = clusterSet.rootNodeSet();
539
540 for (Interval2D root : clusterRoots) {
541 IntervalCluster cluster = new IntervalCluster(clusterSet.disjointSetOf(root));
542 clusterList.add(cluster);
543 }
544
545 Collections.sort(clusterList, new Comparator<IntervalCluster>() {
546 public int compare(IntervalCluster o1, IntervalCluster o2) {
547 return o2.length - o1.length;
548 }
549 });
550
551 int clusterCount = 1;
552 for (IntervalCluster each : clusterList) {
553 each.setId(clusterCount++);
554 }
555 return clusterList;
556 }
557
558 private List<Interval> mergeSegments(List<Interval> intervalList) {
559 Collections.sort(intervalList);
560 List<Interval> result = new ArrayList<Interval>();
561 Interval prev = null;
562 for (Interval each : intervalList) {
563 if (prev == null) {
564 prev = each;
565 continue;
566 }
567
568 if (prev.hasOverlap(each)) {
569 prev = new Interval(prev.getStart(), Math.max(prev.getEnd(), each.getEnd()));
570 }
571 else {
572 result.add(prev);
573 prev = each;
574 }
575 }
576 if (prev != null)
577 result.add(prev);
578
579 return result;
580 }
581
582 private class Segments {
583 public List<Interval> segments = new ArrayList<Interval>();
584 public List<Interval> reverseSegments = new ArrayList<Interval>();
585
586 public void merge() {
587 segments = mergeSegments(segments);
588 reverseSegments = mergeSegments(reverseSegments);
589 }
590
591 public int max(List<Interval> l) {
592 int max = 0;
593 for (Interval each : l)
594 if (max < each.length())
595 max = each.length();
596 return max;
597 }
598
599 public int maxLength() {
600 return Math.max(max(segments), max(reverseSegments));
601 }
602
603 public int size() {
604 return segments.size() + reverseSegments.size();
605 }
606 }
607
608 private Segments mergeSegments(IntervalCluster cluster) {
609
610 Segments seg = new Segments();
611 for (Interval2D each : cluster.component) {
612 int x1 = each.getStart();
613 int x2 = each.getEnd();
614 seg.segments.add(new Interval(x1, x2));
615 if (each.y1 < each.y2)
616 seg.segments.add(new Interval(each.y1, each.y2));
617 else
618 seg.reverseSegments.add(new Interval(each.y1, each.y2));
619 }
620
621 seg.merge();
622
623 return seg;
624 }
625
626 class SegmentReport {
627
628 int segmentID = 0;
629 int clusterID = 0;
630 final String sequence;
631 final SilkWriter silk;
632
633 public SegmentReport() throws IOException, UTGBException {
634
635
636 _logger.info("load fasta sequence: " + fastaFile);
637 FASTA fasta = new FASTA(fastaFile);
638 sequence = fasta.getRawSequence(chr);
639
640
641 File silkFile = new File(outFolder, "cluster-info.silk");
642 silk = new SilkWriter(new BufferedOutputStream(new FileOutputStream(silkFile)));
643 silk.preamble();
644 silk.leaf("date", new Date().toString());
645 silk.leaf("threshold", threshold);
646 silk.leaf("fasta", fastaFile);
647 silk.leaf("dot plot file", intervalFile);
648 }
649
650 public void reportCluster(DisjointSet<Interval2D> clusterSet) throws IOException, UTGBException {
651 Set<Interval2D> clusterRoots = clusterSet.rootNodeSet();
652 _logger.info("# of chains: " + clusterSet.numElements());
653 _logger.info("# of disjoint sets: " + clusterRoots.size());
654
655 List<IntervalCluster> clusterList = createClusters(clusterSet);
656 for (IntervalCluster cluster : clusterList) {
657 clusterID = cluster.id;
658 _logger.info(String.format("cluster %d:(%s)", cluster.id, cluster));
659
660 try {
661 cluster.validate();
662 File outFile = new File(outFolder, String.format("cluster%02d.fa", clusterID));
663 _logger.info("output " + outFile);
664 BufferedWriter fastaOut = new BufferedWriter(new FileWriter(outFile));
665 segmentID = 1;
666 Segments seg = mergeSegments(cluster);
667
668 SilkWriter sub = silk.node("cluster").attribute("id", Integer.toString(clusterID))
669 .attribute("max length", Integer.toString(seg.maxLength())).attribute("component size", Integer.toString(seg.size()));
670
671 outputSegments(seg.segments, sub, fastaOut, false);
672 outputSegments(seg.reverseSegments, sub, fastaOut, true);
673
674 fastaOut.close();
675 }
676 catch (UTGBException e) {
677 _logger.error(e);
678 }
679 }
680 silk.close();
681 }
682
683 public void outputSegments(List<Interval> segments, SilkWriter sub, BufferedWriter fastaOut, boolean isReverse) throws IOException, UTGBException {
684 for (Interval segment : segments) {
685 final int s = segment.getStart();
686 final int e = segment.getEnd();
687 final int id = segmentID++;
688
689 fastaOut.append(String.format(">c%02d-s%04d start:%d, end:%d, len:%d\n", clusterID, id, s, e, e - s));
690 String sSeq = sequence.substring(s - 1, e - 1);
691 sSeq = CompactACGT.createFromString(sSeq).reverseComplement().toString();
692 fastaOut.append(sSeq);
693 fastaOut.append("\n");
694
695 SilkWriter c = sub.node("component").attribute("id", Integer.toString(id)).attribute("x1", Integer.toString(s))
696 .attribute("x2", Integer.toString(e)).attribute("strand", isReverse ? "-" : "+").attribute("len", Integer.toString(e - s));
697
698 }
699 }
700
701 }
702
703 private void findPathsToLeaf(Interval2D current, Interval2D pathStart) {
704
705
706 List<Interval2D> outNodeList = graph.outNodeList(current);
707 if (outNodeList.isEmpty()) {
708
709 Interval2D range = new Interval2D(pathStart, current);
710 rangeList.add(range);
711 }
712 else {
713
714 for (Interval2D next : outNodeList) {
715 findPathsToLeaf(next, pathStart);
716 }
717 }
718
719 }
720
721 }