View Javadoc

1   /*--------------------------------------------------------------------------
2    *  Copyright 2010 utgenome.org
3    *
4    *  Licensed under the Apache License, Version 2.0 (the "License");
5    *  you may not use this file except in compliance with the License.
6    *  You may obtain a copy of the License at
7    *
8    *     http://www.apache.org/licenses/LICENSE-2.0
9    *
10   *  Unless required by applicable law or agreed to in writing, software
11   *  distributed under the License is distributed on an "AS IS" BASIS,
12   *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   *  See the License for the specific language governing permissions and
14   *  limitations under the License.
15   *--------------------------------------------------------------------------*/
16  //--------------------------------------
17  // utgb-core Project
18  //
19  // RepeatChain.java
20  // Since: 2010/10/19
21  //
22  //--------------------------------------
23  package org.utgenome.util.repeat;
24  
25  import java.io.BufferedOutputStream;
26  import java.io.BufferedReader;
27  import java.io.BufferedWriter;
28  import java.io.File;
29  import java.io.FileOutputStream;
30  import java.io.FileReader;
31  import java.io.FileWriter;
32  import java.io.IOException;
33  import java.io.Writer;
34  import java.util.ArrayList;
35  import java.util.Collections;
36  import java.util.Comparator;
37  import java.util.Date;
38  import java.util.List;
39  import java.util.Set;
40  import java.util.TreeMap;
41  
42  import org.utgenome.UTGBErrorCode;
43  import org.utgenome.UTGBException;
44  import org.utgenome.format.fasta.CompactACGT;
45  import org.utgenome.format.fasta.FASTA;
46  import org.utgenome.gwt.utgb.client.bio.Interval;
47  import org.utgenome.gwt.utgb.client.canvas.IntervalTree;
48  import org.xerial.lens.Lens;
49  import org.xerial.silk.SilkWriter;
50  import org.xerial.util.ObjectHandlerBase;
51  import org.xerial.util.graph.AdjacencyList;
52  import org.xerial.util.graph.Edge;
53  import org.xerial.util.log.Logger;
54  import org.xerial.util.opt.Argument;
55  import org.xerial.util.opt.Option;
56  import org.xerial.util.opt.OptionParser;
57  import org.xerial.util.opt.OptionParserException;
58  import org.xerial.util.text.TabAsTreeParser;
59  
60  /**
61   * Perform chaining of interval fragments
62   * 
63   * @author leo
64   * 
65   */
66  public class RepeatChainFinder {
67  
68  	private static Logger _logger = Logger.getLogger(RepeatChainFinder.class);
69  
70  	@Argument(index = 0)
71  	private File intervalFile;
72  	@Argument(index = 1)
73  	private File fastaFile;
74  
75  	@Option(symbol = "t", longName = "threshold", description = "threshold for connecting fragments")
76  	private int threshold = 50;
77  
78  	@Option(symbol = "s", description = "sequence name to read from FASTA")
79  	private String chr;
80  
81  	@Option(symbol = "o", description = "output folder")
82  	private String outFolder;
83  
84  	/**
85  	 * 2D interval
86  	 * 
87  	 * @author leo
88  	 * 
89  	 */
90  	public static class Interval2D extends Interval {
91  
92  		private static final long serialVersionUID = 1L;
93  
94  		public int y1;
95  		public int y2;
96  
97  		public Interval2D() {
98  		}
99  
100 		public Interval2D(int x1, int y1, int x2, int y2) {
101 			super(x1, x2);
102 			this.y1 = y1;
103 			this.y2 = y2;
104 		}
105 
106 		public Interval2D(Interval2D first, Interval2D last) {
107 			super(first.getStart(), last.getEnd());
108 			this.y1 = first.y1;
109 			this.y2 = last.y2;
110 		}
111 
112 		public int compareTo(Interval2D other) {
113 
114 			int diff = getStart() - other.getStart();
115 			if (diff != 0)
116 				return diff;
117 
118 			diff = y1 - other.y1;
119 			if (diff != 0)
120 				return diff;
121 
122 			diff = getEnd() - other.getEnd();
123 			if (diff != 0)
124 				return diff;
125 
126 			diff = y2 - other.y2;
127 			if (diff != 0)
128 				return diff;
129 
130 			return 0;
131 		}
132 
133 		public int maxLength() {
134 			return Math.max(super.length(), Math.abs(y2 - y1));
135 		}
136 
137 		@Override
138 		public String toString() {
139 			return String.format("max len:%d, (%d, %d)-(%d, %d)", maxLength(), getStart(), y1, getEnd(), y2);
140 		}
141 
142 		public Interval startPoint() {
143 			return new Interval(getStart(), y1);
144 		}
145 
146 		public Interval endPoint() {
147 			return new Interval(getEnd(), y2);
148 		}
149 
150 		public boolean isInLowerRightRegion() {
151 			return y1 < getStart();
152 		}
153 
154 		public boolean isForward() {
155 			return y1 <= y2;
156 		}
157 
158 		public int forwardDistance(Interval2D other) {
159 			int xDiff = Math.abs(other.getStart() - this.getEnd());
160 			int yDiff = Math.abs(other.y1 - this.y2);
161 			if (this.isForward()) {
162 				if (other.isForward())
163 					return Math.max(xDiff, yDiff);
164 				else
165 					return -1;
166 			}
167 			else {
168 				if (other.isForward())
169 					return -1;
170 				else
171 					return Math.max(xDiff, yDiff);
172 			}
173 		}
174 
175 		public void toDotFormat(Writer out) throws IOException {
176 			out.append(String.format("%d\t%d\t%d\t%d\n", getStart(), y1, getEnd(), y2));
177 		}
178 
179 		@Override
180 		public boolean equals(Object o) {
181 			Interval2D other = Interval2D.class.cast(o);
182 			return this.getStart() == other.getStart() && this.getEnd() == other.getEnd() && this.y1 == other.y1 && this.y2 == other.y2;
183 		}
184 
185 		@Override
186 		public int hashCode() {
187 			int hash = 3;
188 			hash += 137 * this.getStart();
189 			hash += 137 * this.getEnd();
190 			hash += 137 * this.y1;
191 			hash += 137 * this.y2;
192 			return hash;
193 		}
194 
195 	}
196 
197 	/**
198 	 * 
199 	 * @author leo
200 	 * 
201 	 */
202 	public static class FlippedInterval2D extends Interval {
203 		private static final long serialVersionUID = 1L;
204 		final Interval2D orig;
205 
206 		public FlippedInterval2D(Interval2D orig) {
207 			super(orig.y1, orig.y2);
208 			this.orig = orig;
209 		}
210 
211 	}
212 
213 	public static class IntervalCluster implements Comparable<IntervalCluster> {
214 
215 		public int id;
216 		public final List<Interval2D> component;
217 		public final int length;
218 
219 		public IntervalCluster(List<Interval2D> elements) {
220 			this.component = elements;
221 			Collections.sort(elements, new Comparator<Interval2D>() {
222 				public int compare(Interval2D o1, Interval2D o2) {
223 					return o2.maxLength() - o1.maxLength();
224 				}
225 			});
226 
227 			int maxLength = -1;
228 			for (Interval2D each : elements) {
229 				if (maxLength < each.maxLength())
230 					maxLength = each.maxLength();
231 			}
232 			this.length = maxLength;
233 		}
234 
235 		public void setId(int id) {
236 			this.id = id;
237 		}
238 
239 		public void validate() throws UTGBException {
240 
241 			if (component.size() <= 1)
242 				return;
243 
244 			for (Interval2D p : component) {
245 				boolean hasOverlap = false;
246 				for (Interval2D q : component) {
247 					if (p == q)
248 						continue;
249 					if (p.hasOverlap(q)) {
250 						hasOverlap = true;
251 						break;
252 					}
253 					else {
254 						Interval py = new Interval(p.y1, p.y2);
255 						Interval qy = new Interval(q.y1, q.y2);
256 						if (py.hasOverlap(qy)) {
257 							hasOverlap = true;
258 							break;
259 						}
260 					}
261 				}
262 				if (!hasOverlap)
263 					throw new UTGBException(UTGBErrorCode.ValidationFailure);
264 			}
265 		}
266 
267 		public int compareTo(IntervalCluster o) {
268 			return this.length - o.length;
269 		}
270 
271 		public int size() {
272 			return component.size();
273 		}
274 
275 		public void toDotFile(Writer out) throws IOException {
276 			for (Interval2D each : component) {
277 				each.toDotFormat(out);
278 			}
279 		}
280 
281 		@Override
282 		public String toString() {
283 			return String.format("length:%d, size:%d", length, size());
284 		}
285 
286 	}
287 
288 	public static void main(String[] args) {
289 		RepeatChainFinder finder = new RepeatChainFinder();
290 		OptionParser opt = new OptionParser(finder);
291 		try {
292 			opt.parse(args);
293 			finder.execute(args);
294 		}
295 		catch (OptionParserException e) {
296 			_logger.error(e);
297 		}
298 		catch (Exception e) {
299 			_logger.error(e);
300 			e.printStackTrace(System.err);
301 		}
302 	}
303 
304 	final AdjacencyList<Interval2D, Integer> graph = new AdjacencyList<Interval2D, Integer>();
305 	final ArrayList<Interval2D> rangeList = new ArrayList<Interval2D>();
306 
307 	public void execute(String[] args) throws Exception {
308 
309 		if (intervalFile == null)
310 			throw new UTGBException(UTGBErrorCode.MISSING_FILES, "no input file is given");
311 
312 		if (outFolder == null) {
313 			outFolder = String.format("target/cluster-T%d", threshold);
314 			new File(outFolder).mkdirs();
315 		}
316 
317 		{
318 			int numChain = 0;
319 			final List<Interval2D> intervals = new ArrayList<Interval2D>();
320 
321 			_logger.info("chain threshold: " + threshold);
322 
323 			_logger.info("loading intervals..");
324 			// import (x1, y1, x2, y2) tab-separated data
325 
326 			BufferedReader in = new BufferedReader(new FileReader(intervalFile));
327 			try {
328 				TabAsTreeParser t = new TabAsTreeParser(in);
329 				List<String> label = new ArrayList<String>();
330 				label.add("start");
331 				label.add("y1");
332 				label.add("end");
333 				label.add("y2");
334 				t.setColunLabel(label);
335 				t.setRowNodeName("entry");
336 				// load 2D intervals
337 				Lens.find(Interval2D.class, "entry", new ObjectHandlerBase<Interval2D>() {
338 					public void handle(Interval2D interval) throws Exception {
339 						if (!interval.isInLowerRightRegion())
340 							intervals.add(interval);
341 					}
342 
343 					@Override
344 					public void finish() throws Exception {
345 						_logger.info(String.format("loaded %d intervals", intervals.size()));
346 					}
347 				}, t);
348 			}
349 			finally {
350 				in.close();
351 			}
352 
353 			_logger.info("sorting intervals...");
354 			// sort intervals by their start order
355 			Collections.sort(intervals);
356 
357 			_logger.info("sweeping intervals...");
358 			{
359 				final IntervalTree<Interval2D> intervalTree = new IntervalTree<Interval2D>();
360 				for (Interval2D current : intervals) {
361 
362 					// sweep intervals in [-infinity, current.start - threshold)    
363 					intervalTree.removeBefore(current.getStart() - threshold);
364 
365 					// connect to the close intervals
366 					for (Interval2D each : intervalTree) {
367 						final int dist = each.forwardDistance(current);
368 						if (dist > 0 && dist < threshold) {
369 							graph.addEdge(each, current, dist);
370 						}
371 					}
372 
373 					intervalTree.add(current);
374 				}
375 			}
376 
377 			if (_logger.isTraceEnabled())
378 				_logger.trace("graph:\n" + graph.toGraphViz());
379 
380 			// creating a chain graph
381 			_logger.info("chaining...");
382 			for (Interval2D node : graph.getNodeLabelSet()) {
383 				List<Interval2D> adjacentNodes = new ArrayList<Interval2D>();
384 				for (Edge each : graph.getOutEdgeSet(node)) {
385 					adjacentNodes.add(graph.getNodeLabel(each.getDestNodeID()));
386 				}
387 
388 				if (_logger.isTraceEnabled())
389 					_logger.trace(String.format("node %s -> %s", node, adjacentNodes));
390 			}
391 
392 			// enumerate paths
393 			_logger.info("enumerating connected paths...");
394 			for (Interval2D each : graph.getNodeLabelSet()) {
395 				if (!graph.getInEdgeSet(each).isEmpty())
396 					continue;
397 
398 				// create chain
399 				findPathsToLeaf(each, each);
400 			}
401 			_logger.info("# of paths : " + rangeList.size());
402 
403 			// remove paths sharing the same start or end points
404 			{
405 				//Collections.sort(rangeList);
406 				TreeMap<Interval, Interval2D> longestRange = new TreeMap<Interval, Interval2D>(new Comparator<Interval>() {
407 					public int compare(Interval o1, Interval o2) {
408 						int diff = o1.getStart() - o2.getStart();
409 						if (diff == 0)
410 							return o1.getEnd() - o2.getEnd();
411 						else
412 							return diff;
413 					}
414 				});
415 				{
416 					// merge paths sharing start points
417 					for (Interval2D each : rangeList) {
418 						Interval key = each.startPoint();
419 						if (longestRange.containsKey(key)) {
420 							Interval2D prev = longestRange.get(key);
421 							if (prev.maxLength() < each.maxLength()) {
422 								longestRange.remove(key);
423 								longestRange.put(key, each);
424 							}
425 						}
426 						else
427 							longestRange.put(key, each);
428 					}
429 				}
430 				rangeList.clear();
431 				rangeList.addAll(longestRange.values());
432 
433 				longestRange.clear();
434 				{
435 					// merge paths sharing end points
436 					for (Interval2D each : rangeList) {
437 						Interval key = each.endPoint();
438 						if (longestRange.containsKey(key)) {
439 							Interval2D prev = longestRange.get(key);
440 							if (prev.maxLength() < each.maxLength()) {
441 								longestRange.remove(key);
442 								longestRange.put(key, each);
443 							}
444 						}
445 						else
446 							longestRange.put(key, each);
447 					}
448 				}
449 
450 				rangeList.clear();
451 				rangeList.addAll(longestRange.values());
452 
453 			}
454 			_logger.info("# of unique paths : " + rangeList.size());
455 
456 			BufferedWriter pathOut = new BufferedWriter(new FileWriter(new File(outFolder, "paths.dot")));
457 			try {
458 				//pathOut.append(String.format("%s\t%s\n", "src", "dest"));
459 				for (Interval2D each : rangeList) {
460 					each.toDotFormat(pathOut);
461 				}
462 			}
463 			finally {
464 				pathOut.flush();
465 				pathOut.close();
466 			}
467 
468 			//_logger.info(StringUtil.join(rangeSet, ",\n"));
469 
470 			// assign the overlapped intervals to the same cluster
471 			DisjointSet<Interval2D> clusterSet = new DisjointSet<Interval2D>();
472 			{
473 				_logger.info("clustring paths in X-coordinate...");
474 				IntervalTree<Interval2D> xOverlapChecker = new IntervalTree<Interval2D>();
475 				for (Interval2D each : rangeList) {
476 					clusterSet.add(each);
477 					for (Interval2D overlapped : xOverlapChecker.overlapQuery(each)) {
478 						if (each.contains(overlapped) || overlapped.contains(each))
479 							clusterSet.union(overlapped, each);
480 					}
481 					xOverlapChecker.add(each);
482 				}
483 
484 				_logger.info("# of disjoint sets: " + clusterSet.rootNodeSet().size());
485 			}
486 
487 			{
488 				for (IntervalCluster cluster : createClusters(clusterSet)) {
489 					BufferedWriter xClusterOut = new BufferedWriter(new FileWriter(new File(outFolder, String.format("x_cluster%03d.dot", cluster.id))));
490 					try {
491 						//xClusterOut.append(String.format("%s\t%s\n", "src", "dest"));
492 						cluster.toDotFile(xClusterOut);
493 					}
494 					finally {
495 						xClusterOut.flush();
496 						xClusterOut.close();
497 					}
498 				}
499 			}
500 
501 			{
502 				_logger.info("clustring paths in Y-coordinate...");
503 				IntervalTree<FlippedInterval2D> yOverlapChecker = new IntervalTree<FlippedInterval2D>();
504 				for (Interval2D each : rangeList) {
505 					FlippedInterval2D flip = new FlippedInterval2D(each);
506 					for (FlippedInterval2D overlapped : yOverlapChecker.overlapQuery(flip)) {
507 						if (flip.contains(overlapped) || overlapped.contains(flip))
508 							clusterSet.union(overlapped.orig, each);
509 					}
510 					yOverlapChecker.add(flip);
511 				}
512 
513 				{
514 					for (IntervalCluster cluster : createClusters(clusterSet)) {
515 						BufferedWriter yClusterOut = new BufferedWriter(new FileWriter(new File(outFolder, String.format("y_cluster%03d.dot", cluster.id))));
516 						try {
517 							//yClusterOut.append(String.format("%s\t%s\n", "src", "dest"));
518 							cluster.toDotFile(yClusterOut);
519 						}
520 						finally {
521 							yClusterOut.flush();
522 							yClusterOut.close();
523 						}
524 					}
525 				}
526 
527 				// report clusters
528 				new SegmentReport().reportCluster(clusterSet);
529 			}
530 
531 			_logger.info("done");
532 		}
533 
534 	}
535 
536 	public List<IntervalCluster> createClusters(DisjointSet<Interval2D> clusterSet) {
537 		List<IntervalCluster> clusterList = new ArrayList<IntervalCluster>();
538 		Set<Interval2D> clusterRoots = clusterSet.rootNodeSet();
539 
540 		for (Interval2D root : clusterRoots) {
541 			IntervalCluster cluster = new IntervalCluster(clusterSet.disjointSetOf(root));
542 			clusterList.add(cluster);
543 		}
544 
545 		Collections.sort(clusterList, new Comparator<IntervalCluster>() {
546 			public int compare(IntervalCluster o1, IntervalCluster o2) {
547 				return o2.length - o1.length;
548 			}
549 		});
550 
551 		int clusterCount = 1;
552 		for (IntervalCluster each : clusterList) {
553 			each.setId(clusterCount++);
554 		}
555 		return clusterList;
556 	}
557 
558 	private List<Interval> mergeSegments(List<Interval> intervalList) {
559 		Collections.sort(intervalList);
560 		List<Interval> result = new ArrayList<Interval>();
561 		Interval prev = null;
562 		for (Interval each : intervalList) {
563 			if (prev == null) {
564 				prev = each;
565 				continue;
566 			}
567 
568 			if (prev.hasOverlap(each)) {
569 				prev = new Interval(prev.getStart(), Math.max(prev.getEnd(), each.getEnd()));
570 			}
571 			else {
572 				result.add(prev);
573 				prev = each;
574 			}
575 		}
576 		if (prev != null)
577 			result.add(prev);
578 
579 		return result;
580 	}
581 
582 	private class Segments {
583 		public List<Interval> segments = new ArrayList<Interval>();
584 		public List<Interval> reverseSegments = new ArrayList<Interval>();
585 
586 		public void merge() {
587 			segments = mergeSegments(segments);
588 			reverseSegments = mergeSegments(reverseSegments);
589 		}
590 
591 		public int max(List<Interval> l) {
592 			int max = 0;
593 			for (Interval each : l)
594 				if (max < each.length())
595 					max = each.length();
596 			return max;
597 		}
598 
599 		public int maxLength() {
600 			return Math.max(max(segments), max(reverseSegments));
601 		}
602 
603 		public int size() {
604 			return segments.size() + reverseSegments.size();
605 		}
606 	}
607 
608 	private Segments mergeSegments(IntervalCluster cluster) {
609 
610 		Segments seg = new Segments();
611 		for (Interval2D each : cluster.component) {
612 			int x1 = each.getStart();
613 			int x2 = each.getEnd();
614 			seg.segments.add(new Interval(x1, x2));
615 			if (each.y1 < each.y2)
616 				seg.segments.add(new Interval(each.y1, each.y2));
617 			else
618 				seg.reverseSegments.add(new Interval(each.y1, each.y2));
619 		}
620 
621 		seg.merge();
622 
623 		return seg;
624 	}
625 
626 	class SegmentReport {
627 
628 		int segmentID = 0;
629 		int clusterID = 0;
630 		final String sequence;
631 		final SilkWriter silk;
632 
633 		public SegmentReport() throws IOException, UTGBException {
634 
635 			// fasta
636 			_logger.info("load fasta sequence: " + fastaFile);
637 			FASTA fasta = new FASTA(fastaFile);
638 			sequence = fasta.getRawSequence(chr);
639 
640 			// silk
641 			File silkFile = new File(outFolder, "cluster-info.silk");
642 			silk = new SilkWriter(new BufferedOutputStream(new FileOutputStream(silkFile)));
643 			silk.preamble();
644 			silk.leaf("date", new Date().toString());
645 			silk.leaf("threshold", threshold);
646 			silk.leaf("fasta", fastaFile);
647 			silk.leaf("dot plot file", intervalFile);
648 		}
649 
650 		public void reportCluster(DisjointSet<Interval2D> clusterSet) throws IOException, UTGBException {
651 			Set<Interval2D> clusterRoots = clusterSet.rootNodeSet();
652 			_logger.info("# of chains: " + clusterSet.numElements());
653 			_logger.info("# of disjoint sets: " + clusterRoots.size());
654 
655 			List<IntervalCluster> clusterList = createClusters(clusterSet);
656 			for (IntervalCluster cluster : clusterList) {
657 				clusterID = cluster.id;
658 				_logger.info(String.format("cluster %d:(%s)", cluster.id, cluster));
659 
660 				try {
661 					cluster.validate();
662 					File outFile = new File(outFolder, String.format("cluster%02d.fa", clusterID));
663 					_logger.info("output " + outFile);
664 					BufferedWriter fastaOut = new BufferedWriter(new FileWriter(outFile));
665 					segmentID = 1;
666 					Segments seg = mergeSegments(cluster);
667 
668 					SilkWriter sub = silk.node("cluster").attribute("id", Integer.toString(clusterID))
669 							.attribute("max length", Integer.toString(seg.maxLength())).attribute("component size", Integer.toString(seg.size()));
670 
671 					outputSegments(seg.segments, sub, fastaOut, false);
672 					outputSegments(seg.reverseSegments, sub, fastaOut, true);
673 
674 					fastaOut.close();
675 				}
676 				catch (UTGBException e) {
677 					_logger.error(e);
678 				}
679 			}
680 			silk.close();
681 		}
682 
683 		public void outputSegments(List<Interval> segments, SilkWriter sub, BufferedWriter fastaOut, boolean isReverse) throws IOException, UTGBException {
684 			for (Interval segment : segments) {
685 				final int s = segment.getStart();
686 				final int e = segment.getEnd();
687 				final int id = segmentID++;
688 				// output (x1, x2)
689 				fastaOut.append(String.format(">c%02d-s%04d start:%d, end:%d, len:%d\n", clusterID, id, s, e, e - s));
690 				String sSeq = sequence.substring(s - 1, e - 1);
691 				sSeq = CompactACGT.createFromString(sSeq).reverseComplement().toString();
692 				fastaOut.append(sSeq); // adjust to 0-origin
693 				fastaOut.append("\n");
694 				// output silk
695 				SilkWriter c = sub.node("component").attribute("id", Integer.toString(id)).attribute("x1", Integer.toString(s))
696 						.attribute("x2", Integer.toString(e)).attribute("strand", isReverse ? "-" : "+").attribute("len", Integer.toString(e - s));
697 				//c.leaf("seq", sSeq);
698 			}
699 		}
700 
701 	}
702 
703 	private void findPathsToLeaf(Interval2D current, Interval2D pathStart) {
704 
705 		// TODO cycle detection
706 		List<Interval2D> outNodeList = graph.outNodeList(current);
707 		if (outNodeList.isEmpty()) {
708 			// if this node is a leaf, report the path
709 			Interval2D range = new Interval2D(pathStart, current);
710 			rangeList.add(range);
711 		}
712 		else {
713 			// traverse children
714 			for (Interval2D next : outNodeList) {
715 				findPathsToLeaf(next, pathStart);
716 			}
717 		}
718 
719 	}
720 
721 }