import numpy as np
import os
from matplotlib import colors
import matplotlib.pyplot as plt
#from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.tri as tri
#from dolfin import *
from mpl_toolkits.axes_grid1 import make_axes_locatable

import colormaps as cmaps 

def plotEtaWithAdaptation( etas, nMeshes, path, direcName, adapt, edgecolors="k"):
    """
    MAIN ROUTINE

    For every mesh 0..nMeshes and every requested estimator read the
    element values from the Paraview file of the mesh and save one
    colormap (.eps) per zoom window.

    etas       - names of the estimators to plot; known names and their
                 column in the 'SCALARS etas' table of the .vtk file:
                 'etaAver' 0, 'etaS' 1, 'etaD' 2, 'etaA' 3, 'etaAD' 7,
                 'etaMark' 29. Unknown names are silently skipped.
    nMeshes    - index of the last mesh, meshes 0..nMeshes are used
    path       - path to the directory where the Adgfem results were
                 computed (mesh files tri* and Paraview files Tri*.vtk,
                 see setMeshNames)
    direcName  - folder inside path where the plots are saved, it is
                 created if it does not exist
    adapt      - True for a computation with mesh adaptation; without
                 adaptation there is only one mesh, so nMeshes has to be 0
    edgecolors - colour of the triangle edges, passed to
                 plotConstWithLimits

    Returns None.
    """
    if not adapt and nMeshes > 0:
        print("Problem for NO adaptation there is only one mesh!")
        return

    direc = os.path.join(path, direcName)

    # has to correspond to the content of the file TriA0000*.vtk
    etaIndexes = {'etaAver': 0, 'etaS': 1, 'etaD': 2, 'etaA': 3,
                  'etaAD': 7, 'etaMark': 29}

    # zoom windows ( xLimits, yLimits ), one plot is made for each of them
    limits = (
        ([-20, 20], [-20, 20]),
        ([-0.25, 1.25], [-0.5, 0.5]),
        ([-0.1, 0.1], [-0.1, 0.1]),
    )

    # make the folder if necessary
    if not os.path.isdir(direc):
        os.makedirs(direc)

    for i in range(nMeshes + 1):
        # TODO - the mesh is still read from tri*, replace by Tri*.vtk
        meshName, vtkName = setMeshNames(adapt, path, i)
        triangulation = readMesh(meshName)

        for eta in etas:
            if eta not in etaIndexes:
                continue
            etaVals = readEtaFromParaviewMeshFile(vtkName, etaIndexes[eta])
            for jZoom, zoom in enumerate(limits):
                etaName = setEtaNames(direc, eta, i, jZoom)
                plotConstWithLimits(etaName, etaVals, triangulation,
                                    zoom, edgecolors=edgecolors)

    return
   
	
	
def setMeshNames(adapt, path,i):
    """
    Build the paths of the mesh file and of the Paraview file of mesh i.

    adapt - True  -> files 'triA*****' and 'TriA*****.vtk'
            False -> files 'tri-*****' and 'Tri-*****.vtk'
    path  - directory the file names are joined to
    i     - non-negative integer index of the mesh, zero padded to at
            least 5 digits

    Returns the tuple (meshName, vtkName).
    """
    conn = "A" if adapt else "-"
    suffix = conn + '%05d' % i

    meshName = os.path.join(path, 'tri' + suffix)
    vtkName = os.path.join(path, 'Tri' + suffix + ".vtk")

    return meshName, vtkName

def setEtaNames(direc, etaPrefix, i, jZoom):
    """
    Build the path of the .eps file of one plot.

    direc     - directory the file name is joined to
    etaPrefix - beginning of the file name (name of the estimator)
    i         - non-negative integer index of the mesh, zero padded to at
                least 5 digits
    jZoom     - index of the zoom window

    Returns direc/<etaPrefix>_<iiiii>_<jZoom>.eps
    """
    etaName = '%s_%05d_%s.eps' % (etaPrefix, i, jZoom)
    return os.path.join(direc, etaName)


# Each line of the eta file holds the element values of one estimate.
def readEtaFile( fileName ):
    """
    Read the etaFile: every line holds the element values of one estimate,
    separated by whitespace.

    Returns a list with one list of floats per line of the file (a blank
    line gives an empty list).
    """
    with open(fileName) as etaFile:
        return [[float(token) for token in row.split()] for row in etaFile]
   
def readEtaFromParaviewMeshFile( fileName, index ):
    """
    Read one column of the etas stored in a Paraview file Tri*****.vtk.

    The etas are expected after the two lines
        SCALARS etas double    30
        LOOKUP_TABLE default
    one row of values per element, up to the end of the file.

    fileName - path to the .vtk file
    index    - column of the wanted eta:
               eta_aver = 0, eta_S = 1, eta_D = 2, eta_A = 3, eta_AD = 7

    Returns the list of floats taken from column index, one per row.
    Blank rows are skipped.

    Raises ValueError if the file has no 'SCALARS etas double' line or a
    row holds something that is not a number, IndexError if a row has no
    column index.
    """
    with open(fileName) as f:
        lines = f.readlines()

    # the last matching line starts the table
    headerIndex = -1
    for i, line in enumerate(lines):
        if "SCALARS etas double" in line:
            headerIndex = i
    if headerIndex < 0:
        raise ValueError(
            "no 'SCALARS etas double' section found in " + str(fileName))

    eta = []
    # + 2: skip the SCALARS line and the LOOKUP_TABLE line after it
    for line in lines[headerIndex + 2:]:
        # the whole row is converted so that malformed data is not missed
        row = [float(x) for x in line.split()]
        if row:
            eta.append(row[index])

    return eta
	
   

def readMesh( fileName ):
    """
    Read the mesh from a file like tri-00000 and set tri.Triangulation.

    The file is read in this order:
      1st line         - four integers npoin, nelem, nbelm, nbc
                         (only npoin and nelem are used here)
      2nd line         - skipped
      npoin lines      - coordinates of the nodes, the first two columns
                         are taken as x and y
      nelem lines      - node numbers of the elements, numbered from 1 in
                         the file and shifted to start from 0 here

    Returns the matplotlib.tri.Triangulation of the mesh.
    """
    with open(fileName) as f:
        # read the first line
        npoin, nelem, nbelm, nbc = [int(x) for x in next(f).split()]
        # the second line is not used
        next(f)

        # set the right format needed for tri
        xy = np.asarray(
            [[float(x) for x in next(f).split()] for _ in range(npoin)])
        triangle = np.asarray(
            [[int(x) - 1 for x in next(f).split()] for _ in range(nelem)])

    return tri.Triangulation(xy[:, 0], xy[:, 1], triangle)


def writeArray( array, full_path ):
    """
    Write str( array ) to the file full_path; an existing file is
    overwritten. Returns None.
    """
    # the with statement closes the file, no explicit close is needed
    with open(full_path, 'w') as f:
        f.write( str( array ) )
    return

 
def plotConst( full_path, v, triangulation):
    """
    Makes colormap for a piecewise constant function on ELEMENTS of mesh
    and saves it to a file.

    full_path     - name of the file the figure is saved to
    v             - 1D array of values 1:nelem, one value per triangle
    triangulation - mesh, its attributes x, y and triangles are used

    All figures that are open are closed first. Returns None.
    """
    v = np.asarray( v )

    # start from a clean state with the interactive mode switched off
    plt.close("all")
    plt.ioff()

    # black edges, one colour per triangle, parula color map
    plt.tripcolor(triangulation.x, triangulation.y, triangulation.triangles,
                  facecolors=v, edgecolors='k', cmap=cmaps.parula)

    plt.colorbar()
    plt.savefig( full_path , bbox_inches='tight' )

    print('Colormap was successfuly plotted!')
    return
   
   
def plotConstWithLimits( full_path, v, triangulation, limits =([],[]),
                         edgecolors='none'):
    """
    Makes colormap with a logarithmic colour scale for a piecewise
    constant function on ELEMENTS of mesh, optionally restricted to a
    rectangular window, and saves it to a file.

    full_path     - name of the file the figure is saved to
    v             - 1D array of values 1:nelem, one value per triangle
    triangulation - tri.Triangulation of the mesh; NOTE: when limits are
                    given its mask is overwritten by this routine
    limits        - ( xLimits, yLimits ), each of the form [min, max];
                    if one of them is empty the whole mesh is plotted
    edgecolors    - colour of the triangle edges

    All figures that are open are closed first. Returns None.
    """
    xLimits = limits[0]
    yLimits = limits[1]

    plt.close("all")
    plt.ioff()

    v = np.asarray( v )

    if len(xLimits) > 0 and len(yLimits) > 0:
        # coordinates of the vertices, one row per triangle
        xx = triangulation.x[triangulation.triangles]
        yy = triangulation.y[triangulation.triangles]

        # a triangle stays visible as soon as one of its vertices lies in
        # the window (the bounds included), all the others are masked
        inside = (xx >= xLimits[0]) & (xx <= xLimits[1]) & \
                 (yy >= yLimits[0]) & (yy <= yLimits[1])
        triangulation.set_mask(np.logical_not(inside.any(axis=1)))

        axes = plt.gca()
        axes.axis('scaled')
        axes.set_xlim(xLimits)
        axes.set_ylim(yLimits)

    plt.tripcolor(triangulation, facecolors=v, cmap=cmaps.parula,
                  norm=colors.LogNorm(), edgecolors=edgecolors)

    # colorbar in its own axes on the right, as high as the plot
    divider = make_axes_locatable(plt.gca())
    cax = divider.append_axes("right", "5%", pad="3%")
    plt.colorbar(cax=cax)
    plt.savefig( full_path , bbox_inches='tight' )

    print('Colormap was successfuly plotted! %s' % (full_path,))
    return





   
   
   
def writeErrorsAndRelativeDifference( v, triangulation, direc, name ):
    """
    Plot every row of v and its relative difference from the last row.

    v             - sequence of rows, each row holds one value per element
    triangulation - mesh passed on to plotConst
    direc, name   - the plots are saved as direc/name<k>.eps and
                    direc/name<k>_diff.eps, k counted from 1

    The relative difference of row k is abs((last - row) / last) element
    by element, it is not made for the last row itself.
    NOTE(review): for plain Python numbers a zero in the last row raises
    ZeroDivisionError here - confirm that it cannot occur.
    """
    prefix = os.path.join(direc, name)

    # one plot for each row
    for number, values in enumerate(v, 1):
        plotConst( prefix + str(number) + '.eps', values, triangulation )

    # relative difference between the row and the last one
    for number, values in enumerate(v[:-1], 1):
        relDiff = [abs((last - cur) / last)
                   for last, cur in zip(v[-1], values)]
        plotConst( prefix + str(number) + '_diff.eps', relDiff,
                   triangulation )
    return
   

   
   
def writeMarkingDifference( m, ratio, triangulation, direc, name, elDiff ):
    """
    Mark elements which would be refined in i-th step versus last step
     1 - right refinement and would be not refined due to alg errors
     0 - same
    -1 - should not be refined, are refine only due to ALG errors

    m             - sequence of marking masks (1 = refine), m[-1] is the
                    reference marking
    ratio         - ratio used for the marking, it is only printed
    triangulation - mesh passed on to plotConst
    direc, name   - the plots are saved as direc/name<k>_markDiff.eps,
                    k counted from 1
    elDiff        - for every step the number of elements with the value
                    1 is written to direc/elDiff.txt

    Returns None.
    """
    diffElements = []
    prefix = os.path.join(direc, name)

    # the reference marking does not change inside the loop
    last = m[-1] if len(m) > 0 else []
    nn = sum(1 for mark in last if mark == 1)

    for step, marks in enumerate(m, 1):
        full_path = prefix + str(step) + '_markDiff.eps'

        vv = [x - y for x, y in zip(last, marks)]
        nDiff = sum(1 for d in vv if d == 1)
        diffElements.append( nDiff )

        # one formatted string per line, it prints the same text with the
        # print statement and with the print function
        print('number of all marked elements =  %s' % (nn,))
        print('The number of differently marked triangles is %s '
              '(ratio= %s nelem =  %s )' % (nDiff, ratio, len(vv)))
        plotConst( full_path, vv, triangulation )

    # add diffElements to the file elDiff
    full_path = os.path.join(direc, elDiff) + '.txt'
    writeArray( diffElements, full_path )

    return
   
def marking( v , ratio):
    """
    Mark ( m[i][k] = 1 ) the elements with the highest values for
    refinement, separately in every row of v.

    v     - sequence of rows, each row holds one value per element
    ratio - part of the elements to be marked, a number from [0, 1]

    In a row of n values the threshold is the value on the position
    int(n*ratio) of the row sorted from the highest value; the elements
    strictly above the threshold are marked, so equal values can make the
    number of marked elements smaller than n*ratio. For ratio = 1 there
    is no threshold and all elements are marked; an empty row gives an
    empty mask.

    Returns the list of masks (lists of 0/1), one per row of v.
    Raises ValueError if ratio is not in [0, 1].
    """
    if not 0 <= ratio <= 1:
        raise ValueError("ratio has to be in [0, 1], got %r" % (ratio,))

    m = []  # masks of refined elements
    for values in v:
        ranked = sorted( values, reverse = True )
        lastIndex = int(len(ranked) * ratio)

        if lastIndex >= len(ranked):
            # no value left to serve as the threshold
            m.append( [1] * len(ranked) )
            continue

        lastVal = ranked[lastIndex]
        m.append( [1 if x > lastVal else 0 for x in values] )
    return m
      

def etasPlot( etaName, meshName, direc, name, elDiff, ratio=0.2):
    """
    MAIN ROUTINE

    From the file etaName read the values and plot each row on the mesh
    meshName, which is given in the format of tri-00000 from Adgfem.

    etaName  - file with the etas, one estimate per line
    meshName - file with the mesh
    direc    - folder where the plots will be saved, it is created if it
               does not exist
    name     - beginning of the .eps file names
    elDiff   - name (without .txt) of the file in direc with the numbers
               of differently marked elements
    ratio    - part of the elements with the highest errors which is
               marked for refinement, 0.2 by default

    Returns None.
    """
    # read etas from file and prepare mesh
    v = readEtaFile( etaName )
    triangulation = readMesh( meshName )

    # make the folder if necessary
    if not os.path.isdir( direc ):
        os.makedirs(direc)

    # plot the etas and relative difference from the last (v[-1])
    writeErrorsAndRelativeDifference( v, triangulation, direc, name )

    # mark ( m[i] = 1 ) the ratio part of the elements for refinement
    m = marking( v, ratio )

    # elements which would be refined due to alg errors
    writeMarkingDifference( m, ratio, triangulation, direc, name, elDiff )

    return
    
     
   
############################################################################

#direc    = 'colormaps'
#meshName = 'tri-00000'

#etaSName  = 'etaS_plot'
#elSdiff = 'elem_diff_S'
#nameS     = 'etaS_Color_'

#etaDName  = 'etaD_plot'
#nameD     = 'etaD_Color_'
#elDdiff = 'elem_diff_D'

#etasPlot( etaSName, meshName, direc, nameS, elSdiff )
#etasPlot( etaDName, meshName, direc, nameD, elDdiff  )
