#!/usr/bin/env python3

from pNbody import orbitslib
from numpy import *

import sys

import Ptools as pt

from scipy import optimize

from pNbody import *

from optparse import OptionParser

try:
  import SM
except:
  pass  

##############################################
# option parser
##############################################

def parse_options(argv=None):
  """Parse the command-line options of the surface-of-section script.

  Parameters
  ----------
  argv : list of str or None
      Arguments to parse, without the program name. ``None`` (the default)
      lets optparse read ``sys.argv[1:]``, which is how the script calls it.

  Returns
  -------
  (file, options)
      ``file`` is the first positional argument, or ``None`` when there is
      none; ``options`` is the optparse ``Values`` object holding every
      option declared below.
  """

  usage = "usage: %prog [options] file"
  parser = OptionParser(usage=usage)

  # run control / plotting

  parser.add_option("--plotpotential",
                    action="store_true",
                    dest="plotpotential",
                    default=False,
                    help="plot the potential")

  parser.add_option("--noplot",
                    action="store_true",
                    dest="noplot",
                    default=False,
                    help="do not plot")

  parser.add_option("--nlaps",
                    action="store",
                    dest="nlaps",
                    type="int",
                    default=100,
                    help="number of laps for one orbit")

  parser.add_option("--norbits",
                    action="store",
                    dest="norbits",
                    type="int",
                    default=10,
                    help="number of orbits for one given energy")

  # harmonic potential wx, wy, wz

  parser.add_option("--wx",
                    action="store",
                    dest="wx",
                    type="float",
                    default=0,
                    help="potential frequency in x")

  parser.add_option("--wy",
                    action="store",
                    dest="wy",
                    type="float",
                    default=0,
                    help="potential frequency in y")

  parser.add_option("--wz",
                    action="store",
                    dest="wz",
                    type="float",
                    default=0,
                    help="potential frequency in z")

  # point mass potential GM_pm

  parser.add_option("--GM_pm",
                    action="store",
                    dest="GM_pm",
                    type="float",
                    default=0,
                    help="pm total mass")

  # plummer potential GM_plummer, e

  parser.add_option("--GM_plummer",
                    action="store",
                    dest="GM_plummer",
                    type="float",
                    default=0,
                    help="plummer total mass")

  parser.add_option("--e",
                    action="store",
                    dest="e",
                    type="float",
                    default=0.1,
                    help="Plummer potential softening")

  # miyamoto potential GM_miyamoto, a, b

  parser.add_option("--GM_miyamoto",
                    action="store",
                    dest="GM_miyamoto",
                    type="float",
                    default=0,
                    help="miyamoto total mass")

  parser.add_option("--a",
                    action="store",
                    dest="a",
                    type="float",
                    default=2.9,
                    help="miyamoto a parameter")

  parser.add_option("--b",
                    action="store",
                    dest="b",
                    type="float",
                    default=0.1,
                    help="miyamoto b parameter")

  # logarithmic potential V0, q, p, Rc

  parser.add_option("--V0",
                    action="store",
                    dest="V0",
                    type="float",
                    default=0.,
                    help="logarithmic V0 parameter")

  parser.add_option("--q",
                    action="store",
                    dest="q",
                    type="float",
                    default=0.8,
                    help="logarithmic q parameter (divide y)")

  parser.add_option("--p",
                    action="store",
                    dest="p",
                    type="float",
                    default=1.0,
                    help="logarithmic p parameter (divide z)")

  parser.add_option("--Rc",
                    action="store",
                    dest="Rc",
                    type="float",
                    default=0.1,
                    help="logarithmic Rc parameter")

  parser.add_option("--Lz",
                    action="store",
                    dest="Lz",
                    type="float",
                    default=0.0,
                    help="z component of the angular momentum (set integration in (R-z) plane)")

  # rotation Omega

  parser.add_option("--Omega",
                    action="store",
                    dest="Omega",
                    type="float",
                    default=0.0,
                    help="angular frequency")

  # orbit energy and initial conditions

  parser.add_option("-E",
                    action="store",
                    dest="E",
                    type="float",
                    default=10.,
                    help="orbit energy")

  parser.add_option("--x",
                    action="store",
                    dest="x",
                    type="float",
                    default=None,
                    help="orbit initial x")

  parser.add_option("--vx",
                    action="store",
                    dest="vx",
                    type="float",
                    default=None,
                    help="orbit initial vx")

  parser.add_option("--vxfrac",
                    action="store",
                    dest="vxfrac",
                    type="float",
                    default=0.,
                    help="fraction of max velocity in vx")

  # x range of the starting points / plots

  parser.add_option("--xmax",
                    action="store",
                    dest="xmax",
                    type="float",
                    default=None,
                    help="xmax")

  parser.add_option("--xmin",
                    action="store",
                    dest="xmin",
                    type="float",
                    default=None,
                    help="xmin")

  # extra curves on the surface-of-section plot

  parser.add_option("--add_ILz",
                    action="store_true",
                    dest="add_ILz",
                    default=False,
                    help="add the curve due to Lz")

  parser.add_option("--add_Ix",
                    action="store_true",
                    dest="add_Ix",
                    default=False,
                    help="add the curve due to Ix")

  # argv=None makes optparse fall back to sys.argv[1:]
  (options, args) = parser.parse_args(argv)

  filename = args[0] if args else None

  return filename, options







##############################################
# system functions
##############################################

file, opt = parse_options()

# rk78 integrator settings: initial time step, then position and velocity
# tolerances (all forwarded to orbitslib.IntegrateOneOrbitUsingForces).
dt, epsx, epsv = 1e-5, 1.e-15, 1.e-15

# Harmonic component: squared frequencies along x, y and z.
wx2, wy2, wz2 = opt.wx**2, opt.wy**2, opt.wz**2

# Point mass, then Plummer sphere with softening length e.
GM_pm = opt.GM_pm
GM_plummer, e = opt.GM_plummer, opt.e

# Miyamoto component with its a and b parameters.
GM_miyamoto, a, b = opt.GM_miyamoto, opt.a, opt.b

# Logarithmic component: amplitude V0, axis ratios q (y) and p (z), core Rc.
V0, q, p, Rc = opt.V0, opt.q, opt.p, opt.Rc

# z angular momentum and frame rotation rate.
Lz, Omega = opt.Lz, opt.Omega

# Optional user-imposed x extent; None means it is derived from the energy.
xmax = opt.xmax


def Potential(x,y,z):
  """Return the model potential at (x, y, z).

  The result is the sum of the components enabled by the module-level
  parameters:

  - an anisotropic harmonic well (wx2, wy2, wz2);
  - a point mass (GM_pm > 0);
  - a Plummer sphere with softening ``e`` (GM_plummer > 0);
  - a Miyamoto-type disc with parameters ``a``, ``b`` (GM_miyamoto > 0);
  - a logarithmic potential (V0 > 0, with Rc, q, p);
  - when Omega != 0, the centrifugal term -Omega**2*(x**2+y**2)/2 of a
    frame rotating about z.

  Scalars and numpy arrays (element-wise) are both accepted.
  """
  cyl2 = x**2 + y**2            # cylindrical radius squared
  sph2 = cyl2 + z**2            # spherical radius squared

  phi = 0.5*wx2*x**2 + 0.5*wy2*y**2 + 0.5*wz2*z**2

  if GM_pm > 0:
    phi = phi - GM_pm/sqrt(sph2)

  if GM_plummer > 0:
    phi = phi - GM_plummer/sqrt(sph2 + e**2)

  if GM_miyamoto > 0:
    phi = phi - GM_miyamoto/sqrt(cyl2 + (a + sqrt(z**2 + b**2))**2)

  if V0 > 0:
    phi = phi + 0.5*V0**2*log(Rc**2 + x**2 + y**2/q**2 + z**2/p**2)

  if Omega != 0:
    phi = phi - 0.5*Omega**2*cyl2

  return phi


def Energy(x,y,z,vx,vy,vz):
  """Return the specific energy (kinetic + potential) of a phase-space point.

  Parameters are the position (x, y, z) and the velocity (vx, vy, vz).
  Scalars or numpy arrays are accepted, as for ``Potential``.
  """
  # Named ``kinetic`` rather than ``e`` so it cannot be confused with the
  # module-level Plummer softening ``e``.
  kinetic = 0.5*(vx**2 + vy**2 + vz**2)
  return kinetic + Potential(x,y,z)


def PlotPotential(xmax=5):
    """Plot the cut Phi(x, 0, 0) over [-1.1*xmax, 1.1*xmax] and show it.

    Dashed blue vertical lines are drawn at x = +xmax and x = -xmax. They
    span energies from the value at the origin up to Phi(xmax, 0, 0) + 0.01.
    """
    xline = arange(-xmax*1.1, xmax*1.1, 0.001)
    pt.plot(xline, Potential(xline, 0, 0), 'r')

    bottom = Energy(0, 0, 0, 0, 0, 0)*1.
    top = Energy(xmax, 0, 0, 0, 0, 0)*1 + 0.01
    levels = arange(bottom, top, 0.01)

    # right edge first, then left edge
    for edge in (xmax, -xmax):
        pt.plot(edge*ones(len(levels)), levels, 'b--')

    pt.show()
  


def Plot2dPotential(xmax=5, n=100):
    """Show Phi(x, y, 0) on a square grid and dump the samples to 'surface.dat'.

    Parameters
    ----------
    xmax : float
        Half-width of the window [-xmax, xmax] x [-xmax, xmax].
    n : int
        Number of grid steps across the window; the spacing is 2*xmax/n.
        The default of 100 is the value that was previously hard-coded.

    The sampled points (x, y, Phi) are written, ordered x-major (x outer,
    y inner), as a new gadget-format Nbody file named 'surface.dat'.
    """
    import numpy as np

    dx = 2*xmax/float(n)

    x = np.arange(-xmax, xmax, dx)
    y = np.arange(-xmax, xmax, dx)

    # Evaluate the whole grid at once instead of cell by cell: with the
    # default 'xy' indexing, X[j, i] = x[i] and Y[j, i] = y[j], so
    # mat[j, i] = Phi(x[i], y[j], 0). Rows follow y, which is the layout
    # imshow(origin='lower') expects.
    X, Y = np.meshgrid(x, y)
    mat = np.asarray(Potential(X, Y, 0), dtype=float)

    pt.imshow(mat, interpolation='bilinear', origin='lower',
              extent=(-xmax, xmax, -xmax, xmax))
    pt.show()

    # Transposing before ravel gives x as the outer index and y as the inner one.
    xs = X.T.ravel()
    ys = Y.T.ravel()
    zs = mat.T.ravel()

    # create pNbody object
    pos = np.transpose(np.array([xs, ys, zs]))
    nb = Nbody(pos=pos.astype(np.float32), status="new", ftype='gadget')
    nb.rename('surface.dat')
    nb.write()
    


   
def ComputeXmax(E):
  """Return the x >= 0 on the x axis where Phi(x, 0, 0) equals E.

  The root is bracketed in [0, 1000] and located with scipy's bisection.
  When Omega > 0 and V0 > 0, the upper bracket is
  sqrt((V0**2 - Omega**2*Rc**2)/Omega**2). That is the radius where
  V0**2/(Rc**2 + R**2) = Omega**2, i.e. corotation for the logarithmic
  component. The potential there ("EnergyCR") and the radius ("RadiusCR")
  are printed. If E exceeds EnergyCR, the constant 2 is returned instead.
  """

  def _excess(xpos, E):
    # Phi on the positive x axis minus the target energy
    return Potential(xpos, 0, 0) - E

  lower = 0.0
  upper = 1000

  if Omega > 0 and V0 > 0:
    upper = sqrt((V0**2 - Omega**2*Rc**2)/Omega**2)
    Ecr = Potential(upper, 0, 0)
    print("EnergyCR     = %g"%(Ecr))
    print("RadiusCR     = %g"%(upper))
    if E > Ecr:
      # NOTE(review): hard-coded fallback radius; the value 2 is not
      # explained anywhere in the script — confirm the intent.
      return 2

  root = optimize.bisect(_excess, a=lower, b=upper, args=(E,), xtol=1e-20, maxiter=500)
  return fabs(root)
  


def ComputeVmax(E):
  """Return the speed sqrt(2*(E - Phi(0,0,0))) an orbit of energy E has at the origin."""
  well_depth = E - Potential(0,0,0)
  return sqrt(2*well_depth)


def ComputeVmaxs(E,x):
  """Return sqrt(2*(E - Phi(x,0,0))), the speed at position (x, 0, 0) for energy E.

  ``x`` may be a scalar or a numpy array. Where Phi exceeds E, the result
  is not a real number (nan for numpy inputs).
  """
  kinetic_budget = E - Potential(x,0,0)
  return sqrt(2*kinetic_budget)






####################################################################
# MAIN
####################################################################


E = opt.E

# Potential at the origin.
EnergyMin = Potential(0,0,0)
print("EnergyMin    = %g"%(EnergyMin))

# Without --xmax, take the x-axis turning radius for energy E.
if xmax is None:
  xmax = ComputeXmax(E)
  print("Max Radius   = %g"%(xmax))

vmax = ComputeVmax(E)
print("Max Velocity = %g"%(vmax))


# --plotpotential: only display the potential, then stop.
if opt.plotpotential:
  PlotPotential(xmax=xmax)
  Plot2dPotential(xmax=xmax)
  sys.exit()

if opt.xmin is not None:
  xmin = opt.xmin
else:
  xmin = -xmax


# Starting x positions: either norbits cell centres of [xmin, xmax], or the
# single value given by --x (which forces a single orbit).
if opt.x is None:
  dx = (xmax-xmin)/opt.norbits
  xs = arange(xmin+dx/2.,xmax,dx)
else:
  xs = array([opt.x],float)
  opt.norbits = 1

# One row per lap of every orbit for the points returned as posz/velz.
poss = zeros((opt.norbits*opt.nlaps,3),float)
vels = zeros((opt.norbits*opt.nlaps,3),float)

# Initial x and (canonical) vx of each orbit.
xi  = zeros((opt.norbits,),float)
vxi = zeros((opt.norbits,),float)

##################################
# loop over x
##################################
n = 0
for i,x in enumerate(xs):

  # every orbit starts on the x axis, in the z = 0 plane, with vz = 0
  y  = 0.0
  z  = 0.0
  vz = 0.0

  if opt.vx is None:
    vx = opt.vxfrac*vmax
  else:
    vx = opt.vx

  # Raise real exception objects: raising a plain str is itself a
  # TypeError in Python 3 and would hide the actual problem.
  if x > xmax:
    raise ValueError("x>xmax")

  vxmax_x = ComputeVmaxs(E,x)
  if vx > vxmax_x:
    raise ValueError("vx>vxmax(x)=%g"%(vxmax_x))

  # vy is fixed by the energy constraint E = Phi + (vx**2 + vy**2 + vz**2)/2
  vy = sqrt(2*(E-Potential(x,y,z))-vx**2)

  print("(check) Energy for x=%g vx=%g (vxmax=%g): = %g"%(x,vx,vxmax_x,Energy(x,y,z,vx,vy,vz)))

  # transform into canonical coord (vz unchanged)
  vx = vx - Omega*y
  vy = vy + Omega*x

  pos = array([[x,y,z]],float)
  vel = array([[vx,vy,vz]],float)
  mass= array([0],float)

  xi[i]  = x
  vxi[i] = vx

  # single-orbit runs also keep the full trajectory
  if opt.norbits==1:
    posis = zeros((0,3),float)
    velis = zeros((0,3),float)

  ##################################
  # compute opt.nlaps laps of orbit
  ##################################
  for lap in range(opt.nlaps):

    # dt is fed back in, so the step returned by one call seeds the next
    pos,vel,posi,veli,posz,velz,atime,dt = orbitslib.IntegrateOneOrbitUsingForces(pos,vel,mass,wx2,wy2,wz2,GM_pm,GM_plummer,e,GM_miyamoto,a,b,V0,Rc,q,p,Lz,Omega,epsx,epsv,dt)

    # presumably the surface-of-section crossing for this lap — confirm in orbitslib
    poss[n] = posz
    vels[n] = velz
    n = n + 1

    if opt.norbits==1:
      # Drop rows whose position is exactly zero. Apply the SAME mask to
      # positions and velocities so the two arrays stay paired row by row.
      # Assumes posi and veli have the same number of rows and that unused
      # rows are zero-padded by orbitslib — TODO confirm.
      keep = (posi[:,0]**2 + posi[:,1]**2 + posi[:,2]**2) > 0
      posis = concatenate((posis,compress(keep,posi,axis=0)))
      velis = concatenate((velis,compress(keep,veli,axis=0)))
  

# Single-orbit run: save the trajectory (positions + velocities) to
# 'orbit.dat', keeping the points selected by nb.rxyz() > 0 (presumably
# those away from the origin).
if opt.norbits==1:
  nb = Nbody(pos=posis.astype(float32),vel=velis.astype(float32),status="new",ftype='gadget')
  nb = nb.selectc(nb.rxyz()>0)
  nb.rename('orbit.dat')
  nb.write()


# Section points gathered during the integration. The initial velocities
# were shifted to canonical form (vx - Omega*y, vy + Omega*x); undo that
# shift here, presumably valid for the integrator output as well.
x, y, z = poss[:,0], poss[:,1], poss[:,2]
vx = vels[:,0] + Omega*y
vy = vels[:,1] - Omega*x
vz = vels[:,2]


# Surface-of-section file: the (x, vx) points, then the upper and the lower
# curves vx = +/-sqrt(2*(E - Phi(x,0,0))), all stored at third coordinate E.
dx = 2*xmax/100.
xs = arange(-xmax,xmax+dx,dx)
vs = ComputeVmaxs(E,xs)

newx = concatenate((x, xs, xs))
newy = concatenate((vx, vs, -vs))
newz = E*ones(len(newx))

pos = transpose([newx,newy,newz])
section_pos = pos.astype(float32)
nb = Nbody(pos=section_pos,status="new",ftype='gadget')
nb.rename("surf%07.3f.dat"%(E))
nb.write()


  


############################
# plot
############################

if not opt.noplot:

  pt.figure()
  pt.subplot(1,1,1)

  # Envelope vx = +/-sqrt(2*(E - Phi(x,0,0))) of the accessible region.
  dx = 2*xmax/100.
  xs = arange(-xmax,xmax+dx,dx)
  vs = ComputeVmaxs(E,xs)

  # (x, vx) points collected during the integration
  pt.scatter(x,vx,s=1)
  pt.plot(xs, vs,'r')
  pt.plot(xs,-vs,'r')
  pt.xlabel('x')
  pt.ylabel('vx')
  pt.axis([-xmax,xmax,-vmax,vmax])

  # angular momentum conservation: for trial values Lz_trial, draw
  # vx = +/-sqrt(2*(E - Phi(x,0,0)) - Lz_trial**2/x**2), mirrored to x < 0.
  # A dedicated loop variable is used so the module-level potential
  # parameter Lz (and xmin) are not overwritten as a side effect.
  if opt.add_ILz:

    xgrid = linspace(0,xmax,1000)

    for Lz_trial in linspace(0.1,0.4,10):

      xp = sqrt( (2*xgrid**2 * (E-Potential(xgrid,0,0)) - Lz_trial**2) /xgrid**2 )
      # x = 0 and forbidden regions give non-finite values; drop them
      c = isfinite(xp)
      xc = compress(c,xgrid)
      xp = compress(c,xp)

      pt.plot(xc, xp,'g')
      pt.plot(xc,-xp,'g')

      pt.plot(-xc, xp,'g')
      pt.plot(-xc,-xp,'g')

  # hamiltonian for y=0 vy=0: level curves vx = +/-sqrt(2*(Ex - Phi(x,0,0)))
  if opt.add_Ix:

    # Top level 5% of |E| above E. E*1.05 would fall below E when E < 0.
    Etop = E + 0.05*fabs(E)
    xgrid = linspace(-xmax,xmax,1000)

    for Ex in linspace(EnergyMin,Etop,10):

      xp = sqrt(2*(Ex-Potential(xgrid,0,0)))

      c = isfinite(xp)
      xc = compress(c,xgrid)
      xp = compress(c,xp)

      pt.plot(xc, xp,'y')
      pt.plot(xc,-xp,'y')





# Single-orbit run: draw the trajectory in the (x, y) plane. Skipped with
# --noplot, which previously still opened this figure.
if opt.norbits==1 and not opt.noplot:

  pt.figure()

  px = posis[:,0]
  py = posis[:,1]

  xmin = min(px)
  xmax = max(px)
  ymin = min(py)
  ymax = max(py)
  # symmetric square window, 30% larger than the largest |coordinate|
  xmax = max(xmax,ymax,fabs(xmin),fabs(ymin))*1.3

  pt.subplot(1,1,1)
  pt.plot(px,py)
  pt.xlabel('x')
  pt.ylabel('y')
  pt.axis([-xmax,xmax,-xmax,xmax])

  ax = pt.gca()
  ax.set_aspect('equal', 'box')


# Honour --noplot: do not open any plotting window when it is set.
if not opt.noplot:
  pt.show()






