#!/usr/bin/env python3

from pNbody import orbitslib
from numpy import *

import sys

import Ptools as pt

from scipy import optimize

from pNbody import *

from optparse import OptionParser

try:
  import SM
except:
  pass  

##############################################
# option parser
##############################################

def parse_options():
  """Parse the command line.

  Returns
  -------
  file : str or None
    The first positional argument, or None when no positional argument is given.
  options : optparse.Values
    Potential parameters, orbit energy, initial conditions, sampling
    (nlaps, norbits) and plotting switches.
  """

  usage = "usage: %prog [options] file"
  parser = OptionParser(usage=usage)

  # boolean switches (listed first in --help)
  parser.add_option("--plotpotential", action="store_true", dest="plotpotential",
                    default=False, help="plot the potential")
  parser.add_option("--noplot", action="store_true", dest="noplot",
                    default=False, help="do not plot")

  # valued options: (option string, dest, type, default, help), in --help order
  valued_options = [
    # orbit sampling
    ("--nlaps",       "nlaps",       "int",   100,  "number of laps for one orbit"),
    ("--norbits",     "norbits",     "int",   10,   "number of orbits for one given energy"),
    # harmonic potential wx, wy, wz
    ("--wx",          "wx",          "float", 0,    "potential frequency in x"),
    ("--wy",          "wy",          "float", 0,    "potential frequency in y"),
    ("--wz",          "wz",          "float", 0,    "potential frequency in z"),
    # point mass potential GM_pm
    ("--GM_pm",       "GM_pm",       "float", 0,    "pm total mass"),
    # plummer potential GM_plummer, e
    ("--GM_plummer",  "GM_plummer",  "float", 0,    "plummer total mass"),
    ("--e",           "e",           "float", 0.1,  "Plummer potential softening"),
    # miyamoto potential GM_miyamoto, a, b
    ("--GM_miyamoto", "GM_miyamoto", "float", 0,    "miyamoto total mass"),
    ("--a",           "a",           "float", 2.9,  "miyamoto a parameter"),
    ("--b",           "b",           "float", 0.1,  "miyamoto b parameter"),
    # logarithmic potential V0, q, p, Rc
    ("--V0",          "V0",          "float", 0.,   "logarithmic V0 parameter"),
    ("--q",           "q",           "float", 0.8,  "logarithmic q parameter (divide y)"),
    ("--p",           "p",           "float", 1.0,  "logarithmic p parameter (divide z)"),
    ("--Rc",          "Rc",          "float", 0.1,  "logarithmic Rc parameter"),
    ("--Lz",          "Lz",          "float", 0.0,
     "z component of the angular momentum (set integration in (R-z) plane)"),
    # orbit energy and initial conditions
    ("-E",            "E",           "float", 10.,  "orbit energy"),
    ("--R",           "R",           "float", None, "orbit initial R"),
    ("--vR",          "vR",          "float", None, "orbit initial vR"),
    ("--vRfrac",      "vRfrac",      "float", 0.,   "fraction of max velocity in vR"),
    ("--Rmax",        "Rmax",        "float", None, "Rmax"),
    ("--Rmin",        "Rmin",        "float", None, "Rmin"),
  ]
  for flag, dest, kind, default, text in valued_options:
    parser.add_option(flag, action="store", dest=dest, type=kind,
                      default=default, help=text)

  parser.add_option("--add_IL", action="store_true", dest="add_IL",
                    default=False, help="add the curve due to Ltot")

  (options, args) = parser.parse_args()

  # first positional argument, if any
  filename = args[0] if args else None

  return filename, options







##############################################
# system functions
##############################################

file, opt = parse_options()

# integrator (rk78) settings: initial time step and position/velocity tolerances;
# dt is later updated from the value returned by the integrator
dt, epsx, epsv = 1e-5, 1.e-15, 1.e-15

# Potential parameters, exposed as module-level names because the functions
# below and the orbitslib call read them as globals.
# harmonic: squared frequencies
wx2, wy2, wz2 = opt.wx**2, opt.wy**2, opt.wz**2
# point mass, then Plummer mass and softening
GM_pm, GM_plummer, e = opt.GM_pm, opt.GM_plummer, opt.e
# Miyamoto mass and a, b parameters
GM_miyamoto, a, b = opt.GM_miyamoto, opt.a, opt.b
# logarithmic potential
V0, q, p, Rc = opt.V0, opt.q, opt.p, opt.Rc

# z component of the angular momentum (enters Potential via Lz**2/(2*R**2))
Lz = opt.Lz

# passed unchanged to the integrator (presumably a frame rotation rate -- TODO confirm)
Omega = 0

# radial limits; left as None they are computed from E in the main section
Rmax, Rmin = opt.Rmax, opt.Rmin


def Potential(R,z):
  """Return the effective potential Phi_eff(R, z) in the meridional (R, z) plane.

      Phi_eff = 0.5*V0**2*log(Rc**2 + R**2 + z**2/p**2) + Lz**2/(2*R**2)

  It is the logarithmic potential (module-level V0 and Rc, flattening opt.p)
  plus the centrifugal barrier of the conserved angular momentum Lz. It works
  element-wise on numpy arrays as well as on scalars.

  The logarithmic term is included when V0 != 0 (the same test ComputeRg uses).
  The centrifugal term does not depend on V0. It is skipped when Lz == 0, so
  R == 0 does not divide by zero in that case.
  """
  pot = 0

  if V0 != 0:
    pot = pot + 0.5*V0**2 * log(Rc**2 + R**2 + z**2/opt.p**2)

  # centrifugal barrier
  if Lz != 0:
    pot = pot + Lz**2/(2*R**2)

  return pot


def ComputeRg(Rlo=1e-3, Rhi=1e3):
  """Return the guiding-centre radius Rg for the module-level V0, Rc and Lz.

  Rg is the radius where the midplane effective force vanishes:

      dPhi_eff/dR = V0**2*R/(Rc**2 + R**2) - Lz**2/R**3 = 0

  (the first term only when V0 != 0). The root is found by bisection on
  [Rlo, Rhi]; the defaults 1e-3 and 1e3 reproduce the original fixed bracket.

  Raises
  ------
  ValueError
    If Lz == 0: the effective potential then has no stationary point at R > 0.
    Also raised by scipy when dPhi_eff/dR does not change sign on [Rlo, Rhi].
  """
  if Lz == 0:
    raise ValueError("ComputeRg: Lz must be non-zero (no guiding radius exists for Lz == 0)")

  def dPhidR_eff(R):
    dPhidR = 0

    if V0!=0:
      dPhidR = dPhidR + V0**2 /(Rc**2 + R**2) * R

    # centrifugal (effective) part
    dPhidR = dPhidR - Lz**2/R**3

    return dPhidR

  # xtol is far below double precision; convergence is governed by scipy's default rtol
  Rg = fabs(optimize.bisect(dPhidR_eff, a=Rlo, b=Rhi, args=(), xtol=1e-20, maxiter=500))

  return Rg


def ComputeMinEnergy(Rg):
  """Return Phi_eff at (R, z) = (Rg, 0), used by the script as the minimum orbit energy."""
  z_midplane = 0
  return Potential(Rg, z_midplane)
   

  
  

def Energy(R,z,vR,vz):
  """Return the meridional-plane energy 0.5*(vR**2 + vz**2) + Phi_eff(R, z).

  The azimuthal motion is accounted for by the Lz term inside Potential.
  """
  kinetic = 0.5*(vR**2 + vz**2)
  return kinetic + Potential(R,z)


def PlotPotential(Rmin=0.1,Rmax=0.5):
    """Plot the midplane effective potential Phi_eff(R, 0) and show the figure.

    The red curve spans [Rmin/1.1, 1.1*Rmax] with 100 samples. At each of Rmin
    and Rmax a dashed blue vertical segment runs from Phi_eff at that radius
    down to the lowest sampled value of the curve.
    """
    Rgrid = linspace(Rmin/1.1, Rmax*1.1, 100)
    Phigrid = Potential(Rgrid, 0)
    pt.plot(Rgrid, Phigrid, 'r')

    Phibottom = min(Phigrid)
    for Redge in (Rmin, Rmax):
        ys = linspace(Potential(Redge, 0), Phibottom, 100)
        pt.plot(Redge*ones(len(ys)), ys, 'b--')

    pt.xlabel("R")
    pt.ylabel("Effective Potential")

    pt.show()
  


def Plot2dPotential(xmax=5):
    """Show Phi_eff(R, z) on a square grid and dump the samples to 'surface.dat'.

    The horizontal image axis is R and the vertical one is z. Both are sampled
    with arange(-xmax, xmax, 2*xmax/100). After the image is shown, every grid
    point is written as a particle position (R, z, Phi_eff) into a new
    gadget-format Nbody file renamed to 'surface.dat'.

    Previously this called Potential with three arguments, which raised a
    TypeError. Potential divides by R**2 when Lz != 0, so a sample lying
    exactly on R == 0 may be infinite.
    """
    n = 100
    dx = 2*xmax/float(n)

    x = arange(-xmax,xmax,dx)
    y = arange(-xmax,xmax,dx)

    # X[j,i] = x[i], Y[j,i] = y[j]: the same layout as the image mat[j,i]
    X, Y = meshgrid(x, y)
    # Potential may return the scalar 0 when no term is active; broadcast to the grid
    mat = zeros(X.shape) + Potential(X, Y)

    pt.imshow(mat, interpolation='bilinear',origin='lower',extent=(-xmax,xmax,-xmax,xmax))
    pt.show()

    # one particle per grid point, ordered x-major (for each x, every y)
    xs = X.T.ravel()
    ys = Y.T.ravel()
    zs = mat.T.ravel()
    pos = transpose(array([xs,ys,zs]))
    nb = Nbody(pos=pos.astype(float32),status="new",ftype='gadget')
    nb.rename('surface.dat')
    nb.write()
    


   
def ComputeRmax(E,Rg):
  """Return the outer turning radius R >= Rg where Phi_eff(R, 0) == E.

  The root is bracketed by bisection on [Rg, 1e3]. scipy raises ValueError if
  Phi_eff(R, 0) - E has the same sign at both ends, e.g. E below the minimum
  energy or a turning point beyond 1e3.
  """
  def excess(R, E):
    return Potential(R, 0) - E

  root = optimize.bisect(excess, a=Rg, b=1e3, args=(E,), xtol=1e-20, maxiter=500)
  return fabs(root)

def ComputeRmin(E,Rg):
  """Return the inner turning radius R <= Rg where Phi_eff(R, 0) == E.

  The root is bracketed by bisection on [1e-3, Rg]. scipy raises ValueError if
  Phi_eff(R, 0) - E has the same sign at both ends.
  """
  def excess(R, E):
    return Potential(R, 0) - E

  root = optimize.bisect(excess, a=1e-3, b=Rg, args=(E,), xtol=1e-20, maxiter=500)
  return fabs(root)
  
  


def ComputeVmax(E,Rg):
  """Return the speed sqrt(2*(E - Phi_eff(Rg, 0))) at the guiding radius Rg."""
  kinetic = E - Potential(Rg, 0)
  return sqrt(2*kinetic)


def ComputeVmaxs(E,R):
  """Return the largest |vR| allowed at radius R in the midplane for energy E.

  vRmax(R) = sqrt(2*(E - Phi_eff(R, 0))), element-wise for array R.

  Where Phi_eff(R, 0) exceeds E the result is clipped to 0 instead of NaN.
  This covers round-off at the bisection-computed turning points Rmin and
  Rmax, as well as radii outside the allowed region.
  """
  return sqrt(2*maximum(E-Potential(R,0), 0.0))






####################################################################
# MAIN
####################################################################


E = opt.E

# guiding-centre radius and the matching minimum energy
Rg = ComputeRg()
print("Rg    = %g"%(Rg))

EnergyMin = ComputeMinEnergy(Rg)
print("EnergyMin    = %g"%(EnergyMin))

# radial turning points at energy E, unless given on the command line
if Rmax is None:
  Rmax = ComputeRmax(E,Rg)
  print("Max Radius   = %g"%(Rmax))

if Rmin is None:
  Rmin = ComputeRmin(E,Rg)
  print("Min Radius   = %g"%(Rmin))

# speed at the guiding radius; scales vR through --vRfrac
vmax = ComputeVmax(E,Rg)
print("Max Velocity = %g"%(vmax))

# --plotpotential: only show the potential, then stop
if opt.plotpotential:
  PlotPotential(Rmin=Rmin,Rmax=Rmax)
  sys.exit()




# initial radii: norbits values at the centres of equal cells spanning
# [Rmin, Rmax], or the single radius given by --R
if opt.R is None:
  dR = (Rmax-Rmin)/opt.norbits
  Rs = arange(Rmin+dR/2.,Rmax,dR)
else:
  Rs = array([opt.R],float)
  opt.norbits = 1


# surface-of-section points: one (pos, vel) per lap and per orbit
poss = zeros((opt.norbits*opt.nlaps,3),float)
vels = zeros((opt.norbits*opt.nlaps,3),float)

# initial (R, vR) of each orbit and its squared total angular momentum
xi  = zeros((opt.norbits,),float)
vxi = zeros((opt.norbits,),float)
L2 = zeros((opt.norbits,),float)

##################################
# loop over R
##################################
n = 0
for i,R in enumerate(Rs):

  z  = 0.0
  p  = 0.0  # phi

  vp = Lz   # = P_phi = Phi_point*R^2 = cte

  if opt.vR is None:
    vR = opt.vRfrac*vmax
  else:
    vR = opt.vR

  # raising a plain string is a TypeError in Python 3
  if R > Rmax:
    raise ValueError("R>Rmax")

  # bound |vR|: a large negative vR is just as inaccessible as a large positive one
  if abs(vR) > ComputeVmaxs(E,R):
    raise ValueError("vR>vRmax(x)=%g"%(ComputeVmaxs(E,R)))

  # energy conservation E = 0.5*(vR**2 + vz**2) + Phi_eff(R,z) fixes vz
  # (the old code subtracted 0.5*vR**2, giving an energy of E + vR**2/4);
  # clipping at 0 avoids sqrt of a round-off negative when |vR| == vRmax
  vz = sqrt(max(2*(E-Potential(R,z)) - vR**2, 0.0))

  print("vR = ",vR)
  print("vz = ",vz)
  print("(check) Energy for R=%g vx=%g (vRmax=%g): = %g"%(R,vR,ComputeVmaxs(E,R),Energy(R,z,vR,vz)))

  # squared total angular momentum at z = 0: |r x v|**2 = (R*vz)**2 + Lz**2
  L2[i] = (R**2)*(vz**2) + Lz**2

  xi[i]  = R
  vxi[i] = vR

  # canonical coordinates handed to the integrator:
  # (x, y, z) = (R, z, phi) and (vx, vy, vz) = (vR, vz, P_phi)
  x, y, z = R, z, p
  vx, vy, vz = vR, vz, vp

  pos = array([[x,y,z]],float)
  vel = array([[vx,vy,vz]],float)
  mass= array([0],float)

  # for a single orbit, also accumulate the full sampled trajectory
  if opt.norbits==1:
    posis = zeros((0,3),float)
    velis = zeros((0,3),float)

  ##################################
  # compute opt.nlaps laps of orbit
  ##################################
  for lap in range(opt.nlaps):

    pos,vel,posi,veli,posz,velz,atime,dt = orbitslib.IntegrateOneOrbitUsingForces(pos,vel,mass,wx2,wy2,wz2,GM_pm,GM_plummer,e,GM_miyamoto,a,b,V0,Rc,opt.q,opt.p,Lz,Omega,epsx,epsv,dt)

    # point stored for the surface of section of this lap
    poss[n] = posz
    vels[n] = velz
    n = n + 1

    if opt.norbits==1:
      # remove points that were not used (rows with r <= 0)
      r = posi[:,0]
      posi = compress(r>0,posi,axis=0)
      veli = compress(r>0,veli,axis=0)

      posis = concatenate((posis,posi))
      velis = concatenate((velis,veli))


# single orbit: map the sampled trajectory (R, z, phi) to Cartesian (x, y, z)
# and write it to 'orbit.dat'
if opt.norbits==1:

  r = posis[:,0]
  print(r.min(),r.max())

  z = posis[:,1]
  p = posis[:,2]

  x = r*cos(p)
  y = r*sin(p)
  pos = transpose(array([x, y, z]))

  nb = Nbody(pos=pos.astype(float32),status="new",ftype='gadget')
  # keep particles with rxyz() > 0 (presumably drops points at the origin -- confirm in pNbody)
  nb = nb.selectc(nb.rxyz()>0)
  nb.rename('orbit.dat')
  nb.write()




R = poss[:,0]
vR = vels[:,0]

# surface of section at energy E, stored as positions (R, vR, E): the section
# points first, then the upper and lower zero-velocity curves +-vRmax(R)
Rs = linspace(Rmin,Rmax,1000)
vs = ComputeVmaxs(E,Rs)

newx = concatenate((R, Rs, Rs))
newy = concatenate((vR, vs, -vs))
newz = E*ones(len(newx))

pos = transpose([newx,newy,newz])
nb = Nbody(pos=pos.astype(float32),status="new",ftype='gadget')
nb.rename("surf%07.3f.dat"%(E))
nb.write()


  


############################
# plot
############################

x = poss[:,0]
vx = vels[:,0]

if not opt.noplot:

  pt.figure()
  pt.subplot(1,1,1)

  # zero-velocity curve vR = +-vRmax(R) bounding the accessible region
  Rs = linspace(Rmin,Rmax,10000)
  vs = ComputeVmaxs(E,Rs)

  pt.scatter(xi,vxi,marker='x',color='k')   # initial conditions
  pt.scatter(x,vx,s=1)                      # surface-of-section points
  pt.plot(Rs, vs,'r')
  pt.plot(Rs,-vs,'r')
  pt.xlabel('R')
  pt.ylabel('vR')
  # limits follow the computed Rmax and vmax (vmax bounds vRmax on [Rmin, Rmax]);
  # the module-level Rmin, Rmax and vmax are left untouched
  pt.axis([0, 1.1*Rmax, -1.1*vmax, 1.1*vmax])

  # curves each orbit would trace on the section if its total L**2 were
  # conserved: vR**2 = 2*(E - Phi_eff) - (L2 - Lz**2)/R**2
  if opt.add_IL:
    Rs = linspace(x.min(),x.max(),1000)
    for L2i in L2:
      # `np` is never imported in this file; use numpy's sqrt from the star import
      vRp = sqrt( 2*(E - Potential(Rs,0) - (1/(2*Rs**2))*(L2i-Lz**2)) )
      pt.plot(Rs, vRp,'-g')
      pt.plot(Rs,-vRp,'-g')


# single orbit: its (R, z) track on symmetric, equal-aspect axes
if opt.norbits==1 and not opt.noplot:

  # a fresh single-axes figure (the former subplots(2,1) left stray empty axes)
  pt.figure()
  pt.subplot(1,1,1)

  px = posis[:,0]
  py = posis[:,1]
  xmax = max(px.max(), py.max(), fabs(px.min()), fabs(py.min()))*1.3

  pt.plot(px,py)
  pt.xlabel('R')
  pt.ylabel('z')
  pt.axis([-xmax,xmax,-xmax,xmax])

  ax = pt.gca()
  ax.set_aspect('equal', 'box')


if not opt.noplot:
  pt.show()






