/*-
 * #%L
 * Multiview stitching of large datasets.
 * %%
 * Copyright (C) 2016 - 2017 Big Stitcher developers.
 * %%
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as
 * published by the Free Software Foundation, either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with this program. If not, see
 * <http://www.gnu.org/licenses/gpl-2.0.html>.
 * #L%
 */
package net.preibisch.stitcher.algorithm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

import net.imglib2.Cursor;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.RealInterval;
import net.imglib2.RealLocalizable;
import net.imglib2.algorithm.phasecorrelation.PhaseCorrelation2;
import net.imglib2.algorithm.phasecorrelation.PhaseCorrelationPeak2;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.realtransform.AffineGet;
import net.imglib2.realtransform.AffineTransform;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.realtransform.Translation;
import net.imglib2.realtransform.Translation3D;
import net.imglib2.realtransform.TranslationGet;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.complex.ComplexFloatType;
import net.imglib2.type.numeric.integer.LongType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Pair;
import net.imglib2.util.Util;
import net.imglib2.util.ValuePair;
import net.imglib2.view.Views;
import net.preibisch.legacy.io.IOFunctions;
import net.preibisch.mvrecon.fiji.spimdata.stitchingresults.PairwiseStitchingResult;
import net.preibisch.mvrecon.process.export.DisplayImage;
import net.preibisch.mvrecon.process.fusion.FusionTools;
import net.preibisch.mvrecon.process.fusion.ImagePortion;
import net.preibisch.mvrecon.process.interestpointregistration.TransformationTools;
import net.preibisch.mvrecon.process.interestpointregistration.pairwise.constellation.grouping.Group;
import net.preibisch.stitcher.algorithm.lucaskanade.Align;
import net.preibisch.stitcher.algorithm.lucaskanade.LucasKanadeParameters;
import net.preibisch.stitcher.input.FractalImgLoader;
import net.preibisch.stitcher.input.FractalSpimDataGenerator;

/**
 * Static helpers that register pairs of (translated) images against each other,
 * either by phase correlation ({@link #getShift}) or by Lucas-Kanade alignment
 * ({@link #getShiftLucasKanade}), plus batch variants that process all pairs of a map of images.
 */
public class PairwiseStitching
{
    /**
     * Aligns the overlapping parts of two images with the Lucas-Kanade aligner ({@link Align}).
     *
     * <p>The overlap is derived from the two translations; the images are made zero-min and their
     * singleton dimensions are dropped before alignment. The resulting transform is expanded back to
     * the dimensionality of {@code input1} (identity in singleton dimensions) and pre-concatenated
     * with the translation that compensates for rounding the overlap to integer pixel intervals.
     *
     * @param input1 first image
     * @param input2 second image
     * @param t1 translation of input1, used to determine the overlap
     * @param t2 translation of input2, used to determine the overlap
     * @param params Lucas-Kanade parameters (warp function, iteration limit, minimal parameter change)
     * @param service executor service (currently not used by this method)
     * @param <T> pixel type of both inputs
     * @param <S> unused, kept for backward compatibility of the signature
     * @return pair of transform and correlation (0.0 if the alignment did not converge), or null if
     *         the images do not overlap or the rastered overlap is empty/unequal
     */
    public static < T extends RealType< T >, S extends RealType< S > > Pair< AffineTransform, Double > getShiftLucasKanade(
            final RandomAccessibleInterval< T > input1,
            final RandomAccessibleInterval< T > input2,
            final TranslationGet t1,
            final TranslationGet t2,
            final LucasKanadeParameters params,
            final ExecutorService service )
    {
        // TODO: allow arbitrary pre-registration

        // check if we have singleton dimensions
        // TODO: should we consider cases where a dimension is singleton in one
        // image but not the other?
        final boolean[] singletonDims = findSingletonDimensions( input1, input2 );

        final RealInterval transformed1 = TransformTools.applyTranslation( input1, t1, singletonDims );
        final RealInterval transformed2 = TransformTools.applyTranslation( input2, t2, singletonDims );

        // make sure everything is zero-min
        final RandomAccessibleInterval< T > img1 = zeroMinWithoutSingletons( input1 );
        final RandomAccessibleInterval< T > img2 = zeroMinWithoutSingletons( input2 );

        System.out.println( "1: " + Util.printInterval( img1 ) );
        System.out.println( "1: " + TransformationTools.printRealInterval( transformed1 ) );
        System.out.println( "2: " + Util.printInterval( img2 ) );
        System.out.println( "2: " + TransformationTools.printRealInterval( transformed2 ) );

        final RealInterval overlap = TransformTools.getOverlap( transformed1, transformed2 );
        // NOTE(review): printed before the null check; presumably printRealInterval tolerates null - confirm
        System.out.println( "O: " + TransformationTools.printRealInterval( overlap ) );

        // not overlapping
        if ( overlap == null )
            return null;

        final RealInterval localOverlap1 = TransformTools.getLocalOverlap( transformed1, overlap );
        final RealInterval localOverlap2 = TransformTools.getLocalOverlap( transformed2, overlap );

        final Interval interval1 = TransformTools.getLocalRasterOverlap( localOverlap1 );
        final Interval interval2 = TransformTools.getLocalRasterOverlap( localOverlap2 );

        System.out.println( "1: " + TransformationTools.printRealInterval( localOverlap1 ) );
        System.out.println( "1: " + Util.printInterval( interval1 ) );
        System.out.println( "2: " + TransformationTools.printRealInterval( localOverlap2 ) );
        System.out.println( "2: " + Util.printInterval( interval2 ) );

        // check whether we have 0-sized (or negative sized) or unequal raster overlap intervals
        // (this should just happen with overlaps < 1px in some dimension)
        // ignore this pair in that case
        for ( int d = 0; d < interval1.numDimensions(); ++d )
        {
            final long size1 = interval1.dimension( d );
            final long size2 = interval2.dimension( d );
            if ( size1 <= 0 || size2 <= 0 || size1 != size2 )
            {
                System.out.println( "Rastered overlap volume is zero, skipping." );
                return null;
            }
        }

        // do the alignment on the (zero-min) overlapping crops
        final RandomAccessibleInterval< T > template = Views.zeroMin( Views.interval( img1, interval1 ) );
        final RandomAccessibleInterval< T > target = Views.zeroMin( Views.interval( img2, interval2 ) );

        final Align< T > lkAlign = new Align< T >( template, new ArrayImgFactory< FloatType >(), params.getWarpFunctionInstance( img1.numDimensions() ) );
        final AffineTransform res = lkAlign.align( target, params.maxNumIterations, params.minParameterChange );

        final boolean converged = lkAlign.didConverge();
        if ( converged )
            IOFunctions.println("(" + new Date( System.currentTimeMillis() ) + ") determined transformation:" + Util.printCoordinates( res.getRowPackedCopy() ) );
        else
            IOFunctions.println("(" + new Date( System.currentTimeMillis() ) + ") registration did not converge" );

        final int nFull = input1.numDimensions();
        final AffineTransform resFull = new AffineTransform( nFull );

        // increase dimensionality of transform if necessary:
        // copy the reduced matrix into the rows/columns of the non-singleton dimensions
        int reducedRow = 0;
        for ( int d = 0; d < nFull; ++d )
        {
            if ( singletonDims[ d ] )
                continue;

            int reducedCol = 0;
            for ( int col = 0; col <= nFull; ++col )
            {
                // the last column (translation) is always taken over
                if ( col == nFull || !singletonDims[ col ] )
                {
                    resFull.set( res.get( reducedRow, reducedCol ), d, col );
                    reducedCol++;
                }
            }
            reducedRow++;
        }

        // get subpixel offset before alignment
        final double[] subpixelOffset = new double[ nFull ];
        int d2 = 0;
        for ( int d = 0; d < nFull; ++d )
        {
            // singleton dimension: we did not calculate any transformation here
            if ( singletonDims[ d ] )
                continue;

            // correct for the int/real coordinate mess
            final double intervalSubpixelOffset1 = interval1.realMin( d2 ) - localOverlap1.realMin( d2 ); // a_s
            final double intervalSubpixelOffset2 = interval2.realMin( d2 ) - localOverlap2.realMin( d2 ); // b_s
            subpixelOffset[ d ] = ( intervalSubpixelOffset2 - intervalSubpixelOffset1 );
            d2++;
        }

        // correct for subpixel offset
        final AffineTransform subpixelT = new AffineTransform( nFull );
        for ( int d = 0; d < nFull; d++ )
            subpixelT.set( subpixelOffset[ d ], d, nFull );

        resFull.preConcatenate( subpixelT );

        return new ValuePair<>( resFull, converged ? lkAlign.getCurrentCorrelation( target ) : 0.0 );
    }

    /**
     * Determines the shift between two images by phase correlation of their overlapping parts.
     *
     * <p>t1 and t2 are used to restrict the computation to the overlap of the two images. If
     * {@code params.useWholeImage} is set, identity translations are used for that purpose instead
     * and the existing relative translation (t2 - t1) is subtracted from the result. The inputs
     * need not be zero-min; they are made zero-min and singleton dimensions are dropped internally.
     * The shift in singleton dimensions is reported as 0.
     *
     * @param input1 - first image
     * @param input2 - second image
     * @param t1 - translation of input1
     * @param t2 - translation of input2
     * @param params - stitching parameters
     * @param service - executor service to use
     * @param <T> pixel type input1
     * @param <S> pixel type input2
     * @return pair of shift vector and cross correlation coefficient or null if no shift could be determined
     */
    public static < T extends RealType< T >, S extends RealType< S > > Pair< Translation, Double > getShift(
            final RandomAccessibleInterval< T > input1,
            final RandomAccessibleInterval< S > input2,
            final TranslationGet t1,
            final TranslationGet t2,
            final PairwiseStitchingParameters params,
            final ExecutorService service )
    {
        // check if we have singleton dimensions
        // TODO: should we consider cases where a dimension is singleton in one image but not the other?
        final boolean[] singletonDims = findSingletonDimensions( input1, input2 );

        final RealInterval transformed1 = TransformTools.applyTranslation( input1, params.useWholeImage ? new Translation3D() : t1, singletonDims );
        final RealInterval transformed2 = TransformTools.applyTranslation( input2, params.useWholeImage ? new Translation3D() : t2, singletonDims );

        // make sure it is zero-min and drop singleton dimensions
        final RandomAccessibleInterval< T > img1 = zeroMinWithoutSingletons( input1 );
        final RandomAccessibleInterval< S > img2 = zeroMinWithoutSingletons( input2 );

        // echo intervals
        System.out.println( "1: " + Util.printInterval( img1 ) );
        System.out.println( "1: " + TransformationTools.printRealInterval( transformed1 ) );
        System.out.println( "2: " + Util.printInterval( img2 ) );
        System.out.println( "2: " + TransformationTools.printRealInterval( transformed2 ) );

        // get overlap interval
        final RealInterval overlap = TransformTools.getOverlap( transformed1, transformed2 );
        // NOTE(review): printed before the null check; presumably printRealInterval tolerates null - confirm
        System.out.println( "O: " + TransformationTools.printRealInterval( overlap ) );

        // not overlapping -> we wont be able to determine a shift
        if ( overlap == null )
            return null;

        // get overlap in images' coordinates
        final RealInterval localOverlap1 = TransformTools.getLocalOverlap( transformed1, overlap );
        final RealInterval localOverlap2 = TransformTools.getLocalOverlap( transformed2, overlap );

        // round to integer interval
        final Interval interval1 = TransformTools.getLocalRasterOverlap( localOverlap1 );
        final Interval interval2 = TransformTools.getLocalRasterOverlap( localOverlap2 );

        // echo intervals
        System.out.println( "1: " + TransformationTools.printRealInterval( localOverlap1 ) );
        System.out.println( "1: " + Util.printInterval( interval1 ) );
        System.out.println( "2: " + TransformationTools.printRealInterval( localOverlap2 ) );
        System.out.println( "2: " + Util.printInterval( interval2 ) );

        // check whether we have 0-sized (or negative sized) or unequal raster overlapIntervals
        // (this should just happen with overlaps < 1px in some dimension)
        // ignore this pair in that case
        // FIXED for downsampling=2 caused by up/down-rounding (see TransformTools.getLocalRasterOverlap)
        // TODO: in pre-transformed views (e.g. both rotated), we might sometimes have unequal overlap due to numerical imprecision?
        // -> look into this (still not fixed!) >> should be fixed now
        for ( int d = 0; d < interval1.numDimensions(); ++d )
        {
            final long size1 = interval1.dimension( d );
            final long size2 = interval2.dimension( d );

            if ( size1 <= 0 || size2 <= 0 )
            {
                IOFunctions.println( "Rastered overlap between volumes is zero, skipping." );
                return null;
            }

            if ( size1 != size2 )
            {
                IOFunctions.println( "Rastered overlap between volumes in dim " + d + " is unequal ("+size1+"<>"+size2+"), skipping." );
                return null;
            }
        }

        //
        // call the phase correlation
        //
        final int[] extension = new int[ img1.numDimensions() ];
        Arrays.fill( extension, 10 );

        //
        // the min overlap is a fraction (params.minOverlap) of the current overlap interval
        //
        long minOverlap = 1;
        for ( int d = 0; d < interval1.numDimensions(); d++ )
            minOverlap *= interval1.dimension( d );
        minOverlap *= params.minOverlap;

        //System.out.println( "Min overlap is: " + minOverlap );

        System.out.println( "FFT" );

        // the zero-min overlapping crops of both images, used for the PCM and for peak verification
        final RandomAccessibleInterval< T > crop1 = Views.zeroMin( Views.interval( img1, interval1 ) );
        final RandomAccessibleInterval< S > crop2 = Views.zeroMin( Views.interval( img2, interval2 ) );

        // TODO: Do not extend by mirror inside, but do that out here on the full image,
        // so we feed it RandomAccessible + an Interval we want to use for the PCM > also zero-min inside
        final RandomAccessibleInterval< FloatType > pcm = PhaseCorrelation2.calculatePCM(
                crop1,
                crop2,
                extension,
                new ArrayImgFactory< FloatType >(),
                new FloatType(),
                new ArrayImgFactory< ComplexFloatType >(),
                new ComplexFloatType(),
                service );

        normalizePCM( pcm, service );

        final PhaseCorrelationPeak2 shiftPeak = PhaseCorrelation2.getShift(
                pcm,
                crop1,
                crop2,
                params.peaksToCheck,
                minOverlap,
                params.doSubpixel,
                params.interpolateCrossCorrelation,
                service );

        //System.out.println( "Actual overlap of best shift is: " + shiftPeak.getnPixel() );

        // the best peak is horrible or no peaks were found at all, return null
        if ( shiftPeak == null || Double.isInfinite( shiftPeak.getCrossCorr() ) )
            return null;

        // prefer the subpixel-localized shift if one is available
        final RealLocalizable subpixelShift = shiftPeak.getSubpixelShift();
        final RealLocalizable shift = ( subpixelShift == null ) ? shiftPeak.getShift() : subpixelShift;

        // final, relative shift
        final int nFull = input1.numDimensions();
        final double[] finalShift = new double[ nFull ];
        int d2 = 0;
        for ( int d = 0; d < nFull; ++d )
        {
            if ( singletonDims[ d ] )
            {
                // we ignored these axes during phase correlation -> set their shift to 0
                finalShift[ d ] = 0.0;
            }
            else
            {
                // correct for the int/real coordinate mess
                final double intervalSubpixelOffset1 = interval1.realMin( d2 ) - localOverlap1.realMin( d2 ); // a_s
                final double intervalSubpixelOffset2 = interval2.realMin( d2 ) - localOverlap2.realMin( d2 ); // b_s
                final double localRasterShift = shift.getDoublePosition( d2 ); // d'

                System.out.println( intervalSubpixelOffset1 + "," + intervalSubpixelOffset2 + "," + localRasterShift );

                finalShift[ d ] = localRasterShift - ( intervalSubpixelOffset2 - intervalSubpixelOffset1 );
                d2++;
            }

            // if we used the whole image, subtract existing shift
            if ( params.useWholeImage )
                finalShift[ d ] -= t2.getTranslation( d ) - t1.getTranslation( d );
        }

        return new ValuePair< >( new Translation( finalShift ), shiftPeak.getCrossCorr() );
    }

    /**
     * Rescales the phase correlation matrix in place: every value v is replaced by
     * sqrt( v - min + 0.01 ), where min is the smallest value of the PCM.
     *
     * @param pcm phase correlation matrix, modified in place
     * @param service executor service to use
     */
    public static void normalizePCM( final RandomAccessibleInterval< FloatType > pcm, final ExecutorService service )
    {
        // so that the peak doesn't stick out too much, that interferes with the subpixel detection
        final float min = min( pcm, service );
        adjustPCM( pcm, min, service );
    }

    /**
     * Computes the minimal pixel value of an image, multithreaded over image portions.
     *
     * @param img the image
     * @param taskExecutor executor service to run the portion tasks on
     * @return the minimum, or {@link Float#NaN} if the computation failed
     */
    public static float min( final RandomAccessibleInterval< FloatType > img, final ExecutorService taskExecutor )
    {
        final IterableInterval< FloatType > iterable = Views.iterable( img );

        // split up into many parts for multithreading
        final List< ImagePortion > portions = FusionTools.divideIntoPortions( iterable.size() );

        // one task per portion, each yielding the minimum of its portion
        final List< Callable< Float > > tasks = new ArrayList<>();
        for ( final ImagePortion portion : portions )
        {
            tasks.add( () -> {
                float portionMin = Float.MAX_VALUE;

                final Cursor< FloatType > c = iterable.cursor();
                c.jumpFwd( portion.getStartPosition() );

                for ( long j = 0; j < portion.getLoopSize(); ++j )
                    portionMin = Math.min( portionMin, c.next().get() );

                return portionMin;
            } );
        }

        // run threads and combine results
        float min = Float.MAX_VALUE;
        try
        {
            // invokeAll() returns when all tasks are complete
            for ( final Future< Float > future : taskExecutor.invokeAll( tasks ) )
                min = Math.min( min, future.get() );
        }
        catch ( final Exception e )
        {
            IOFunctions.println( "Failed to compute min: " + e );
            e.printStackTrace();
            return Float.NaN;
        }

        return min;
    }

    /**
     * Replaces every pixel value v of the image by sqrt( v - min + 0.01 ), multithreaded over
     * image portions. Failures (including failures inside the worker tasks) are logged.
     *
     * @param img the image, modified in place
     * @param min the value to subtract before taking the square root
     * @param taskExecutor executor service to run the portion tasks on
     */
    public static void adjustPCM( final RandomAccessibleInterval< FloatType > img, final float min, final ExecutorService taskExecutor )
    {
        final IterableInterval< FloatType > iterable = Views.iterable( img );

        // split up into many parts for multithreading
        final List< ImagePortion > portions = FusionTools.divideIntoPortions( iterable.size() );

        final List< Callable< Void > > tasks = new ArrayList<>();
        for ( final ImagePortion portion : portions )
        {
            tasks.add( () -> {
                final Cursor< FloatType > c = iterable.cursor();
                c.jumpFwd( portion.getStartPosition() );

                for ( long j = 0; j < portion.getLoopSize(); ++j )
                {
                    final FloatType t = c.next();
                    t.set( (float)Math.sqrt( t.get() - min + 0.01 ) );
                }

                return null;
            } );
        }

        try
        {
            // invokeAll() returns when all tasks are complete; get() surfaces exceptions thrown
            // inside a task, which would otherwise be silently dropped
            for ( final Future< Void > future : taskExecutor.invokeAll( tasks ) )
                future.get();
        }
        catch ( final Exception e )
        {
            IOFunctions.println( "Failed to subtract: " + e );
        }
    }

    /**
     * Runs {@link #getShiftLucasKanade} on all pairs (index1 &lt; index2, in sorted key order) of the
     * given images and collects the results of the pairs for which a transform was determined.
     *
     * @param rais images by key
     * @param translations translation of each image by key
     * @param params Lucas-Kanade parameters
     * @param service executor service to use
     * @param <T> pixel type
     * @param <C> key type
     * @return one result per successfully aligned pair, keyed by ({first view}, {second view})
     */
    public static < T extends RealType< T >, C extends Comparable< C > > List< PairwiseStitchingResult< C > > getPairwiseShiftsLucasKanade(
            final Map< C, RandomAccessibleInterval< T > > rais,
            final Map< C, TranslationGet > translations,
            final LucasKanadeParameters params,
            final ExecutorService service )
    {
        final List< C > indexes = new ArrayList< >( rais.keySet() );
        Collections.sort( indexes );

        final List< PairwiseStitchingResult< C > > result = new ArrayList< >();

        // go through all pairs with index1 < index2
        for ( int i = 0; i < indexes.size(); i++ )
        {
            final C first = indexes.get( i );
            for ( int j = i + 1; j < indexes.size(); j++ )
            {
                final C second = indexes.get( j );

                final Pair< AffineTransform, Double > resT = getShiftLucasKanade(
                        rais.get( first ), rais.get( second ),
                        translations.get( first ), translations.get( second ),
                        params, service );

                if ( resT != null )
                {
                    final Pair< Group< C >, Group< C > > key = singletonGroupPair( first, second );
                    result.add( new PairwiseStitchingResult< C >( key, null, resT.getA(), resT.getB(), 0.0 ) );
                }
            }
        }

        return result;
    }

    /**
     * Runs {@link #getShift} on all pairs (index1 &lt; index2, in sorted key order) of the given
     * images and collects the results of the pairs for which a shift was determined.
     *
     * @param rais images by key
     * @param translations translation of each image by key
     * @param params stitching parameters
     * @param service executor service to use
     * @param <T> pixel type
     * @param <C> key type
     * @return one result per successfully registered pair, keyed by ({first view}, {second view})
     */
    public static < T extends RealType< T >, C extends Comparable< C > > List< PairwiseStitchingResult< C > > getPairwiseShifts(
            final Map< C, RandomAccessibleInterval< T > > rais,
            final Map< C, TranslationGet > translations,
            final PairwiseStitchingParameters params,
            final ExecutorService service )
    {
        final List< C > indexes = new ArrayList< >( rais.keySet() );
        Collections.sort( indexes );

        final List< PairwiseStitchingResult< C > > result = new ArrayList< >();

        // go through all pairs with index1 < index2
        for ( int i = 0; i < indexes.size(); i++ )
        {
            final C first = indexes.get( i );
            for ( int j = i + 1; j < indexes.size(); j++ )
            {
                final C second = indexes.get( j );

                final Pair< Translation, Double > resT = getShift(
                        rais.get( first ), rais.get( second ),
                        translations.get( first ), translations.get( second ),
                        params, service );

                if ( resT != null )
                {
                    final Pair< Group< C >, Group< C > > key = singletonGroupPair( first, second );
                    result.add( new PairwiseStitchingResult< C >( key, null, resT.getA(), resT.getB(), 0.0 ) );
                }
            }
        }

        return result;
    }

    /**
     * Small manual test: generates overlapping tiles of fractal test data and computes the
     * pairwise shifts between all of them, starting from deliberately wrong tile translations.
     *
     * @param args ignored
     */
    public static void main( String[] args )
    {
        final AffineTransform3D m = new AffineTransform3D();
        double scale = 200;
        m.set( scale, 0.0f, 0.0f, 0.0f, 0.0f, scale, 0.0f, 0.0f, 0.0f, 0.0f, scale, 0.0f );

        final AffineTransform3D mShift = new AffineTransform3D();
        double shift = 100;
        mShift.set( 1.0f, 0.0f, 0.0f, shift, 0.0f, 1.0f, 0.0f, shift, 0.0f, 0.0f, 1.0f, shift );

        final AffineTransform3D mShift2 = new AffineTransform3D();
        double shift2x = 1200;
        double shift2y = 300;
        mShift2.set( 1.0f, 0.0f, 0.0f, shift2x, 0.0f, 1.0f, 0.0f, shift2y, 0.0f, 0.0f, 1.0f, 0.0f );

        final AffineTransform3D mShift3 = new AffineTransform3D();
        double shift3x = 500;
        double shift3y = 1300;
        mShift3.set( 1.0f, 0.0f, 0.0f, shift3x, 0.0f, 1.0f, 0.0f, shift3y, 0.0f, 0.0f, 1.0f, 0.0f );

        AffineTransform3D m2 = m.copy();
        AffineTransform3D m3 = m.copy();
        m.preConcatenate( mShift );
        m2.preConcatenate( mShift2 );
        m3.preConcatenate( mShift3 );

        Interval start = new FinalInterval( new long[] { -399, -399, 0 }, new long[] { 0, 0, 1 } );
        List< Interval > intervals = FractalSpimDataGenerator.generateTileList( start, 7, 6, 0.2f );
        // tile positions generated with a different overlap, used as (wrong) initial translations
        List< Interval > falseStarts = FractalSpimDataGenerator.generateTileList( start, 7, 6, 0.30f );

        FractalSpimDataGenerator fsdg = new FractalSpimDataGenerator( 3 );
        fsdg.addFractal( m );
        fsdg.addFractal( m2 );
        fsdg.addFractal( m3 );

        Map< Integer, RandomAccessibleInterval< LongType > > rais = new HashMap< >();
        Map< Integer, TranslationGet > tr = new HashMap< >();

        List< TranslationGet > tileTranslations = FractalSpimDataGenerator.getTileTranslations( falseStarts );

        FractalImgLoader imgLoader = (FractalImgLoader) fsdg.generateSpimData( intervals ).getSequenceDescription()
                .getImgLoader();
        for ( int i = 0; i < intervals.size(); i++ )
        {
            rais.put( i, imgLoader.getImageAtInterval( intervals.get( i ) ) );
            tr.put( i, tileTranslations.get( i ) );
        }

        // the pool uses non-daemon threads: shut it down, otherwise the JVM never exits
        final ExecutorService service = Executors.newFixedThreadPool( Runtime.getRuntime().availableProcessors() );
        try
        {
            List< PairwiseStitchingResult< Integer > > pairwiseShifts = getPairwiseShifts( rais, tr,
                    new PairwiseStitchingParameters(),
                    service );

            Map< Integer, AffineGet > collect = tr.entrySet().stream().collect( Collectors.toMap( e ->
                e.getKey(), e -> {AffineTransform3D res = new AffineTransform3D(); res.set( e.getValue().getRowPackedCopy() ); return res; } ));

            // TODO: replace with new globalOpt code

//            Map< Set<Integer>, AffineGet > globalOptimization = GlobalTileOptimization.twoRoundGlobalOptimization( new TranslationModel3D(),
//                    rais.keySet().stream().map( ( c ) -> {Set<Integer> s = new HashSet<>(); s.add( c ); return s;}).collect( Collectors.toList() ),
//                    null,
//                    collect,
//                    pairwiseShifts, new GlobalOptimizationParameters() );
//
//            System.out.println( globalOptimization );
        }
        finally
        {
            service.shutdown();
        }
    }

    /** Marks every dimension in which at least one of the two intervals has size &lt;= 1. */
    private static boolean[] findSingletonDimensions( final Interval input1, final Interval input2 )
    {
        final boolean[] singletonDims = new boolean[ input1.numDimensions() ];
        for ( int d = 0; d < singletonDims.length; ++d )
            singletonDims[ d ] = input1.dimension( d ) <= 1 || input2.dimension( d ) <= 1;
        return singletonDims;
    }

    /** Returns a view of the image that is zero-min and has its singleton dimensions dropped. */
    private static < P > RandomAccessibleInterval< P > zeroMinWithoutSingletons( final RandomAccessibleInterval< P > img )
    {
        if ( Views.isZeroMin( img ) )
            return Views.dropSingletonDimensions( img );

        return Views.dropSingletonDimensions( Views.zeroMin( img ) );
    }

    /** Builds the result key ({first}, {second}): one single-view group per view of the pair. */
    private static < C extends Comparable< C > > Pair< Group< C >, Group< C > > singletonGroupPair( final C first, final C second )
    {
        final Set< C > setA = new HashSet<>();
        setA.add( first );
        final Set< C > setB = new HashSet<>();
        setB.add( second );

        final Pair< Group< C >, Group< C > > key = new ValuePair<>( new Group<>( setA ), new Group<>( setB ) );
        return key;
    }
}