diff --git a/STalign/STalign.py b/STalign/STalign.py index 23af422..a5b6f3c 100644 --- a/STalign/STalign.py +++ b/STalign/STalign.py @@ -916,7 +916,7 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, a=500.0,p=2.0,expand=2.0,nt=3, niter=5000,diffeo_start=0, epL=2e-8, epT=2e-1, epV=2e3, sigmaM=1.0,sigmaB=2.0,sigmaA=5.0,sigmaR=5e5,sigmaP=2e1, - device='cpu',dtype=torch.float64, muB=None, muA=None): + device='cpu',dtype=torch.float64, muB=None, muA=None, display=True): ''' Run LDDMM between a pair of images. This jointly estimates an affine transform A, and a diffeomorphism phi. @@ -995,6 +995,9 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, If the target is a grayscale image, this should be a tensor of size 1. muB: torch tensor whose dimension is the same as the target image Defaults to None, which means we estimate this. If you provide a value, we will not estimate it. + display: binary + Defaults to True + Decides if the plots of the function will be shown Returns a dictionary ------- @@ -1011,6 +1014,8 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, Resulting weight 2D array (background) 'WA': torch tensor Resulting weight 2D array (artifact) + 'Errors': list + List of the progresion of the errors in algingment } ''' @@ -1089,10 +1094,11 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, #ax.imshow(K[0].cpu()) DV = torch.prod(dv) Ki = torch.fft.ifftn(K).real - fig,ax = plt.subplots() - ax.imshow(Ki.clone().detach().cpu().numpy(),vmin=0.0,extent=extentV) - ax.set_title('smoothing kernel') - fig.canvas.draw() + if display: + fig,ax = plt.subplots() + ax.imshow(Ki.clone().detach().cpu().numpy(),vmin=0.0,extent=extentV) + ax.set_title('smoothing kernel') + fig.canvas.draw() # nt = 3 @@ -1137,9 +1143,10 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, else: estimate_muB = False - fig,ax = plt.subplots(2,3) - ax = ax.ravel() - figE,axE = plt.subplots(1,3) + if display: + fig,ax = plt.subplots(2,3) + ax = ax.ravel() + figE,axE = plt.subplots(1,3) Esave = [] try: @@ -1253,57 +1260,58 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, # draw if not it%10: - ax[0].cla() - ax[0].imshow( ((AI-torch.amin(AI,(1,2))[...,None,None])/(torch.amax(AI,(1,2))-torch.amin(AI,(1,2)))[...,None,None]).permute(1,2,0).clone().detach().cpu(),extent=extentJ) - ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - ax[0].set_title('space tformed source') - - ax[1].cla() - ax[1].imshow(clip(fAI.permute(1,2,0).clone().detach()/torch.max(J).item()).cpu(),extent=extentJ) - ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - ax[1].set_title('contrast tformed source') - - ax[5].cla() - ax[5].imshow(clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,0).clone().detach().cpu()*0.5+0.5,extent=extentJ) - ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) - ax[5].set_title('Error') - - ax[2].cla() - ax[2].imshow(J.permute(1,2,0).cpu()/torch.max(J).item(),extent=extentJ) - ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) - ax[2].set_title('Target') - - ax[4].cla() - ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu(),extent=extentJ) - ax[4].set_title('Weights') - - - toshow = v[0].clone().detach().cpu() - toshow /= torch.max(torch.abs(toshow)) - toshow = toshow*0.5+0.5 - toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1) - ax[3].cla() - ax[3].imshow(clip(toshow),extent=extentV) - ax[3].set_title('velocity') - - axE[0].cla() - axE[0].plot(Esave) - axE[0].legend(['E','EM','ER','EP']) - axE[0].set_yscale('log') - axE[1].cla() - axE[1].plot([e[:2] for e in Esave]) - axE[1].legend(['E','EM']) - axE[1].set_yscale('log') - axE[2].cla() - axE[2].plot([e[2] for e in Esave]) - axE[2].legend(['ER']) - axE[2].set_yscale('log') - - - - fig.canvas.draw() - figE.canvas.draw() + if display: + ax[0].cla() + ax[0].imshow( ((AI-torch.amin(AI,(1,2))[...,None,None])/(torch.amax(AI,(1,2))-torch.amin(AI,(1,2)))[...,None,None]).permute(1,2,0).clone().detach().cpu(),extent=extentJ) + ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + ax[0].set_title('space tformed source') + + ax[1].cla() + ax[1].imshow(clip(fAI.permute(1,2,0).clone().detach()/torch.max(J).item()).cpu(),extent=extentJ) + ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + ax[1].set_title('contrast tformed source') + + ax[5].cla() + ax[5].imshow(clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,0).clone().detach().cpu()*0.5+0.5,extent=extentJ) + ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) + ax[5].set_title('Error') + + ax[2].cla() + ax[2].imshow(J.permute(1,2,0).cpu()/torch.max(J).item(),extent=extentJ) + ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) + ax[2].set_title('Target') + + ax[4].cla() + ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu(),extent=extentJ) + ax[4].set_title('Weights') + + + toshow = v[0].clone().detach().cpu() + toshow /= torch.max(torch.abs(toshow)) + toshow = toshow*0.5+0.5 + toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1) + ax[3].cla() + ax[3].imshow(clip(toshow),extent=extentV) + ax[3].set_title('velocity') + + axE[0].cla() + axE[0].plot(Esave) + axE[0].legend(['E','EM','ER','EP']) + axE[0].set_yscale('log') + axE[1].cla() + axE[1].plot([e[:2] for e in Esave]) + axE[1].legend(['E','EM']) + axE[1].set_yscale('log') + axE[2].cla() + axE[2].plot([e[2] for e in Esave]) + axE[2].legend(['ER']) + axE[2].set_yscale('log') + + + + fig.canvas.draw() + figE.canvas.draw() return { 'A': A.clone().detach(), @@ -1311,7 +1319,8 @@ def LDDMM(xI,I,xJ,J,pointsI=None,pointsJ=None, 'xv': xv, 'WM': WM.clone().detach(), 'WB': WB.clone().detach(), - 'WA': WA.clone().detach() + 'WA': WA.clone().detach(), + 'Errors': Esave, } @@ -1320,7 +1329,7 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None, a=500.0,p=2.0,expand=1.25,nt=3, niter=5000,diffeo_start=0, epL=1e-6, epT=1e1, epV=1e3, sigmaM=1.0,sigmaB=2.0,sigmaA=5.0,sigmaR=1e8,sigmaP=2e1, - device='cpu',dtype=torch.float64, muA=None, muB = None): + device='cpu',dtype=torch.float64, muA=None, muB = None, display=True): ''' LDDMM for 3D to 2D slice mapping. muA: torch tensor whose dimension is the same as the target image @@ -1395,10 +1404,11 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None, #ax.imshow(K[0].cpu()) DV = torch.prod(dv) Ki = torch.fft.ifftn(K).real - fig,ax = plt.subplots() - ax.imshow(Ki[Ki.shape[0]//2].clone().detach().cpu().numpy(),vmin=0.0,extent=extentV) - ax.set_title('smoothing kernel') - fig.canvas.draw() + if display: + fig,ax = plt.subplots() + ax.imshow(Ki[Ki.shape[0]//2].clone().detach().cpu().numpy(),vmin=0.0,extent=extentV) + ax.set_title('smoothing kernel') + fig.canvas.draw() # steps epL = torch.tensor(epL,device=device,dtype=dtype) @@ -1442,10 +1452,11 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None, ''' # a figure - fig,ax = plt.subplots(2,3) - ax = ax.ravel() - figE,axE = plt.subplots(1,3) - axE = axE.ravel() + if display: + fig,ax = plt.subplots(2,3) + ax = ax.ravel() + figE,axE = plt.subplots(1,3) + axE = axE.ravel() Esave = [] # zero gradients try: @@ -1564,60 +1575,61 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None, # draw if not it%10: - ax[0].cla() - Ishow = ((AI-torch.amin(AI,(1,2,3))[...,None,None])/(torch.amax(AI,(1,2,3))-torch.amin(AI,(1,2,3)))[...,None,None,None]).permute(1,2,3,0).clone().detach().cpu() - ax[0].imshow( Ishow[0,...,0] ,extent=extentJ) - #ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - ax[0].set_title('space tformed source') - - ax[1].cla() - Ishow = clip(fAI.permute(1,2,3,0).clone().detach()/torch.max(J).item()).cpu() - ax[1].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1) - #ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - ax[1].set_title('contrast tformed source') - - ax[5].cla() - Ishow = clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,3,0).clone().detach().cpu()*0.5+0.5 - ax[5].imshow(Ishow[0,...,0],extent=extentJ) - #ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) - #ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) - ax[5].set_title('Error') - - ax[2].cla() - Ishow = J.permute(1,2,3,0).cpu()/torch.max(J).item() - ax[2].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1) - #ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) - ax[2].set_title('Target') - - ax[4].cla() - ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu()[0],extent=extentJ) - ax[4].set_title('Weights') - - - toshow = v[0].clone().detach().cpu() # initial velocity, components are rgb - toshow /= torch.max(torch.abs(toshow)) - toshow = toshow*0.5+0.5 - #toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1) - ax[3].cla() - ax[3].imshow(clip(toshow)[toshow.shape[0]//2],extent=extentV) - ax[3].set_title('velocity') - - axE[0].cla() - axE[0].plot(Esave) - axE[0].legend(['E','EM','ER','EP']) - axE[0].set_yscale('log') - axE[1].cla() - axE[1].plot([e[:2] for e in Esave]) - axE[1].legend(['E','EM']) - axE[1].set_yscale('log') - axE[2].cla() - axE[2].plot([e[2] for e in Esave]) - axE[2].legend(['ER']) - axE[2].set_yscale('log') - - - fig.canvas.draw() - figE.canvas.draw() + if display: + ax[0].cla() + Ishow = ((AI-torch.amin(AI,(1,2,3))[...,None,None])/(torch.amax(AI,(1,2,3))-torch.amin(AI,(1,2,3)))[...,None,None,None]).permute(1,2,3,0).clone().detach().cpu() + ax[0].imshow( Ishow[0,...,0] ,extent=extentJ) + #ax[0].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + ax[0].set_title('space tformed source') + + ax[1].cla() + Ishow = clip(fAI.permute(1,2,3,0).clone().detach()/torch.max(J).item()).cpu() + ax[1].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1) + #ax[1].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + ax[1].set_title('contrast tformed source') + + ax[5].cla() + Ishow = clip( (fAI - J)/(torch.max(J).item())*3.0 ).permute(1,2,3,0).clone().detach().cpu()*0.5+0.5 + ax[5].imshow(Ishow[0,...,0],extent=extentJ) + #ax[5].scatter(pointsIt[:,1].clone().detach().cpu(),pointsIt[:,0].clone().detach().cpu()) + #ax[5].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) + ax[5].set_title('Error') + + ax[2].cla() + Ishow = J.permute(1,2,3,0).cpu()/torch.max(J).item() + ax[2].imshow(Ishow[0,...,0],extent=extentJ,vmin=0,vmax=1) + #ax[2].scatter(pointsJ[:,1].clone().detach().cpu(),pointsJ[:,0].clone().detach().cpu()) + ax[2].set_title('Target') + + ax[4].cla() + ax[4].imshow(clip(torch.stack((WM,WA,WB),-1).clone().detach()).cpu()[0],extent=extentJ) + ax[4].set_title('Weights') + + + toshow = v[0].clone().detach().cpu() # initial velocity, components are rgb + toshow /= torch.max(torch.abs(toshow)) + toshow = toshow*0.5+0.5 + #toshow = torch.cat((toshow,torch.zeros_like(toshow[...,0][...,None])),-1) + ax[3].cla() + ax[3].imshow(clip(toshow)[toshow.shape[0]//2],extent=extentV) + ax[3].set_title('velocity') + + axE[0].cla() + axE[0].plot(Esave) + axE[0].legend(['E','EM','ER','EP']) + axE[0].set_yscale('log') + axE[1].cla() + axE[1].plot([e[:2] for e in Esave]) + axE[1].legend(['E','EM']) + axE[1].set_yscale('log') + axE[2].cla() + axE[2].plot([e[2] for e in Esave]) + axE[2].legend(['ER']) + axE[2].set_yscale('log') + + + fig.canvas.draw() + figE.canvas.draw() return { 'A': A.clone().detach(), @@ -1626,7 +1638,8 @@ def LDDMM_3D_to_slice(xI,I,xJ,J,pointsI=None,pointsJ=None, 'WM': WM.clone().detach(), 'WB': WB.clone().detach(), 'WA': WA.clone().detach(), - 'Xs': Xs.clone().detach() + 'Xs': Xs.clone().detach(), + 'Errors': Esave, }