#ifndef __TOOLS__
#define __TOOLS__


#include <cmath>
#include <iostream>
#include <map>
#include <stdexcept>
#include <vector>

#include "trigger/util.hh"
#include "geometry.hh"


/**
 * \file
 * Tools package.
 */

namespace TOOLS {

  using namespace UTIL;
  using namespace GEOMETRY;

  /**
   * Equality operator for vector3d.
   *
   * The x, y and z components are compared with exact (==) comparison;
   * no tolerance is applied.
   *
   * The function is declared inline because its body lives in this header:
   * a non-inline definition would be emitted by every translation unit that
   * includes the header and clash at link time.
   *
   * \param  a   vector3d
   * \param  b   vector3d
   * \return     true if a == b; else false
   */
  inline bool operator==(const vector3d& a, const vector3d& b)
  {
    return (a.x() == b.x() &&
            a.y() == b.y() &&
            a.z() == b.z());
  }

  /**
   * Equality operator for point.
   *
   * The x, y, z and t components are compared with exact (==) comparison;
   * no tolerance is applied.
   *
   * The function is declared inline because its body lives in this header:
   * a non-inline definition would be emitted by every translation unit that
   * includes the header and clash at link time.
   *
   * \param  a   point
   * \param  b   point
   * \return     true if a == b; else false
   */
  inline bool operator==(const point& a, const point& b)
  {
    return (a.x() == b.x() &&
            a.y() == b.y() &&
            a.z() == b.z() &&
            a.t() == b.t());
  }

  /**
   * Equality operator for axis.
   *
   * The parameters a, b, theta and phi are compared with exact (==)
   * comparison; no tolerance is applied.
   *
   * The function is declared inline because its body lives in this header:
   * a non-inline definition would be emitted by every translation unit that
   * includes the header and clash at link time.
   *
   * \param  a   axis
   * \param  b   axis
   * \return     true if a == b; else false
   */
  inline bool operator==(const axis& a, const axis& b)
  {
    return (a.a()     == b.a()     &&
            a.b()     == b.b()     &&
            a.theta() == b.theta() &&
            a.phi()   == b.phi());
  }

  /**
   * Equality operator for track.
   *
   * The parameters a, b, theta, phi and t0 are compared with exact (==)
   * comparison; no tolerance is applied.
   *
   * The function is declared inline because its body lives in this header:
   * a non-inline definition would be emitted by every translation unit that
   * includes the header and clash at link time.
   *
   * \param  a   track
   * \param  b   track
   * \return     true if a == b; else false
   */
  inline bool operator==(const track& a, const track& b)
  {
    return (a.a()     == b.a()     &&
            a.b()     == b.b()     &&
            a.theta() == b.theta() &&
            a.phi()   == b.phi()   &&
            a.t0()    == b.t0());
  }


  /**
   * Dot product of two vector3d objects.
   *
   * \param  a vector3d
   * \param  b vector3d
   * \return   a.x*b.x + a.y*b.y + a.z*b.z
   */
  static inline double vdot(const vector3d& a, const vector3d& b)
  {
    // the terms are summed in x, y, z order
    return a.x() * b.x()  +  a.y() * b.y()  +  a.z() * b.z();
  }

  /**
   * Dot product of two point objects.
   *
   * All four components (x, y, z and t) contribute to the sum.
   *
   * \param  a point
   * \param  b point
   * \return   a.x*b.x + a.y*b.y + a.z*b.z + a.t*b.t
   */
  static inline double vdot(const point& a, const point& b)
  {
    // the terms are summed in x, y, z, t order
    return a.x() * b.x()  +  a.y() * b.y()  +  a.z() * b.z()  +  a.t() * b.t();
  }

  /**
   * Dot product of two axis objects.
   *
   * The sum runs over the products of the parameters a, b, theta and phi.
   *
   * \param  a axis
   * \param  b axis
   * \return   a.a*b.a + a.b*b.b + a.theta*b.theta + a.phi*b.phi
   */
  static inline double vdot(const axis& a, const axis& b)
  {
    // the terms are summed in a, b, theta, phi order
    return a.a() * b.a()  +  a.b() * b.b()  +  a.theta() * b.theta()  +  a.phi() * b.phi();
  }

  /**
   * Dot product of two track objects.
   *
   * The sum runs over the products of the parameters a, b, theta, phi and t0.
   *
   * \param  a track
   * \param  b track
   * \return   a.a*b.a + a.b*b.b + a.theta*b.theta + a.phi*b.phi + a.t0*b.t0
   */
  static inline double vdot(const track& a, const track& b)
  {
    // the terms are summed in a, b, theta, phi, t0 order
    return a.a()     * b.a()      +
           a.b()     * b.b()      +
           a.theta() * b.theta()  +
           a.phi()   * b.phi()    +
           a.t0()    * b.t0();
  }
  
  
  /**
   * Distance between track1Z and position.
   * \param  ta track
   * \param  o  position
   * \return    arrival time [ns]
   */
  static inline double distance(const track1Z& ta, const position& o)
  {
    t_pos dx = o.x() - ta.x();
    t_pos dy = o.y() - ta.y();
    
    return sqrt(dx*dx + dy*dy);
  }

  /**
   * Distance between track1Z and position.
   * \param  ta track
   * \param  o  position
   * \return    arrival time [ns]
   */
  template<class T>
  static inline double distance(const track1Z& ta, const T& o)
  {
    t_pos dx = o.x() - ta.x();
    t_pos dy = o.y() - ta.y();
    
    return sqrt(dx*dx + dy*dy);
  }

  /**
   * Distance between axis and position.
   * \param  a_ axis
   * \param  o_ position
   * \return    distance [m]
   */
  static inline double distance(const axis& a_, const position& o_)
  {
    rotation R(a_);
    position o(o_);
    
    o.rotate(R);
    
    double da = o.x() - a_.a();
    double db = o.y() - a_.b();
    
    return sqrt(da*da + db*db);	
  }

  /**
   * Distance between axis and position.
   * \param  a_ axis
   * \param  o_ position
   * \return    distance [m]
   */
  template<class T>
  static inline double distance(const axis& a_, const T& o_)
  {
    rotation R(a_);
    position o(o_.x(),o_.y(),o_.z());
    
    o.rotate(R);
    
    double da = o.x() - a_.a();
    double db = o.y() - a_.b();
    
    return sqrt(da*da + db*db);	
  }
  
  
  /**
   * Hit time prediction.
   * \param  ta track
   * \param  o  position
   * \return    arrival time [ns]
   */
  static inline double time(const track1Z& ta, const position& o)
  {
    t_pos dx = o.x() - ta.x();
    t_pos dy = o.y() - ta.y();
    t_pos dz = o.z() - ta.z();
    
    t_pos r  = sqrt(dx*dx + dy*dy);
    
    return ta.t() + (dz + r * getTanThetaC()) * C_INVERSE;
  }

  /**
   * Hit time prediction.
   * \param  ta track
   * \param  o  position
   * \return    arrival time [ns]
   */
  template<class T>
  static inline double time(const track1Z& ta, const T& o)
  {
    t_pos dx = o.x() - ta.x();
    t_pos dy = o.y() - ta.y();
    t_pos dz = o.z() - ta.z();
    
    t_pos r  = sqrt(dx*dx + dy*dy);
    
    return ta.t() + (dz + r * getTanThetaC()) * C_INVERSE;
  }


  /**
   * Hit time residual (hit time - prediction).
   * \param  ta track
   * \param  o_ hit object
   * \return    residual [ns]
   */
  static inline double residual(const track1Z& ta, const point& o_)
  {
    t_pos dx = o_.x() - ta.x();
    t_pos dy = o_.y() - ta.y();
    t_pos dz = o_.z() - ta.z();
    
    t_pos r  = sqrt(dx*dx + dy*dy);
    
    return o_.t() - ta.t() - (dz + r * getTanThetaC()) * C_INVERSE;
  }

  /**
   * Hit time residual (hit time - prediction).
   * \param  ta track
   * \param  o_ hit object
   * \return    residual [ns]
   */
  template<class T>
  static inline double residual(const track1Z& ta, const T& o_)
  {
    t_pos dx = o_.x() - ta.x();
    t_pos dy = o_.y() - ta.y();
    t_pos dz = o_.z() - ta.z();
    
    t_pos r  = sqrt(dx*dx + dy*dy);
    
    return o_.t() - ta.t() - (dz + r * getTanThetaC()) * C_INVERSE;
  }

  /**
   * Hit time residual (hit time - prediction).
   * \param  ta track
   * \param  o_ hit object
   * \return    residual [ns]
   */
  static inline double residual(const track& ta, const point& o_)
  {
    rotation R(ta);
    position o(o_);
    
    o.rotate(R);
    
    double da = o.x() - ta.a();
    double db = o.y() - ta.b();
    
    double r  = sqrt(da*da + db*db);	
    double t  = ta.t0()  +  (o.z()  +  r * getTanThetaC()) * C_INVERSE;
    
    return o_.t() - t;
  }

  /**
   * Hit time residual (hit time - prediction).
   * \param  ta track
   * \param  o_ hit object
   * \return    residual [ns]
   */
  template<class T>
  static inline double residual(const track& ta, const T& o_)
  {
    rotation R(ta);
    position o(o_.x(),o_.y(),o_.z());
    
    o.rotate(R);
    
    double da = o.x() - a();
    double db = o.y() - b();
    
    double r  = sqrt(da*da + db*db);	
    double t  = ta.t0()  +  (o.z()  +  r * getTanThetaC()) * C_INVERSE;
    
    return o_.t() - t;
  }


  /**
   * Intercept between two tracks.
   *
   * The position difference (ta - tb) and the direction of tb are rotated with
   * the rotation constructed from the direction of ta. In that frame the
   * difference is moved along the direction of tb by the path length that
   * minimises its extent in the (x,y) plane; the result is rotated back.
   * When the rotated direction of tb has no (x,y) component, no extrapolation
   * is applied (path length zero).
   *
   * \param  ta track
   * \param  tb track
   * \return    intercept
   */
  static inline position intercept(const track& ta, const track& tb)
  {
    position delta((const position&) ta - (const position&) tb);
    position step(tb.dx(), tb.dy(), tb.dz());

    rotation frame((const direction&) ta);

    // go to the rotated frame

    delta.rotate(frame);
    step .rotate(frame);

    // path length to the position at minimal distance of approach

    const t_pos alpha = delta.x() * step.x()  +  delta.y() * step.y();
    const t_pos beta  = step.x()  * step.x()  +  step.y()  * step.y();

    t_pos path = 0;

    if (beta != 0) {
      path = -alpha/beta;
    }

    // extrapolate

    delta += step * path;

    // return to the original frame

    return delta.rotate_back(frame);
  }


  /**
   * Time difference between two tracks at minimal distance of approach.
   *
   * The result is (ta.t0() - tb.t0()) corrected by
   * dot(pos,dir) * C_INVERSE / (1 + dot(direction of ta, direction of tb)),
   * where pos is the position difference (ta - tb) and dir is the sum of the
   * two track directions.
   *
   * NOTE: the denominator vanishes when the dot product of the two directions
   * equals -1; no guard is applied for that case.
   *
   * \param  ta track [m,ns]
   * \param  tb track [m,ns]
   * \return    time [ns]
   */
  static inline t_time time(const track& ta, const track& tb)
  {
    position pos(ta.x()  - tb.x(),
                 ta.y()  - tb.y(),
                 ta.z()  - tb.z());

    position dir(ta.dx() + tb.dx(),
                 ta.dy() + tb.dy(),
                 ta.dz() + tb.dz());

    t_time tab = ta.t0() - tb.t0();

    // view the tracks as directions without casting away their constness
    tab -= dot(pos,dir) * C_INVERSE / (1.0 + dot((const direction&) ta, (const direction&) tb));

    return tab;
  }


  /**
   * Function interface for chi2 minimisation.
   *
   * The quantity to be minimised is the sum over all data points j of rho(u_j),
   *
   *   chi2 = sum_j rho(u_j)
   *
   * where u_j is the normalised residual of data point j, i.e. u_j = residual_j / sigma_j.
   *
   * psi(u) is the first derivative of rho, i.e. psi(u') = d/du rho(u)|u=u'
   *
   * ups(u) is the second derivative of rho, i.e. ups(u') = d2/(du)2 rho(u)|u=u'
   */
  class FunctionInterface {
  public:
    virtual ~FunctionInterface() {}

    //! function value
    virtual double rho(const double& u) const = 0;

    //! first derivative
    virtual double psi(const double& u) const = 0;

    //! second derivative
    virtual double ups(const double& u) const = 0;
  };

  /**
   * Normal distribution: rho(u) = u^2 / 2.
   */
  class Chi2 : public FunctionInterface {
  public:
    virtual ~Chi2() {}

    //! function value u^2 / 2
    virtual double rho(const double& u) const
    {
      return 0.5*u*u;
    }

    //! first derivative u
    virtual double psi(const double& u) const
    {
      return u;
    }

    //! second derivative, equal to 1 for any u
    virtual double ups(const double& u) const
    {
      return 1.0;
    }
  };

  /**
   * Lorentzian distribution: rho(u) = log(1 + u^2 / 2).
   */
  class M_Estimate : public FunctionInterface {
  public:
    virtual ~M_Estimate() {}

    //! function value log(1 + u^2 / 2)
    virtual double rho(const double& u) const
    {
      const double denominator = 1.0 + 0.5*u*u;

      return log(denominator);
    }

    //! first derivative u / (1 + u^2 / 2)
    virtual double psi(const double& u) const
    {
      const double denominator = 1.0 + 0.5*u*u;

      return u / denominator;
    }

    //! second derivative (1 - u^2 / 2) / (1 + u^2 / 2)^2
    virtual double ups(const double& u) const
    {
      const double half        = 0.5*u*u;
      const double denominator = 1.0 + half;

      return (1.0 - half) / (denominator * denominator);
    }
  };


  /**
   * Calculation of M-estimate of data.
   *
   * For every hit in [begin_, end_) the residual with respect to the track is
   * divided by the sigma of the hit and passed to fcn.rho(); the results are
   * summed. An empty range yields 0.
   *
   * \param ta         track
   * \param begin_     begin of hit data
   * \param end_       end of hit data
   * \param fcn        chi2 function
   * \return           M-estimate
   */
  template<class Track_, class hitIterator>
  double m_estimate(const Track_& ta,
                    const hitIterator& begin_,
                    const hitIterator& end_,
                    const FunctionInterface& fcn)
  {
    double sum = 0.0;

    for (hitIterator hit = begin_; hit != end_; ++hit) {

      // normalised residual of this hit
      const double u = residual(ta,*hit) / hit->sigma();

      sum += fcn.rho(u);
    }

    return sum;
  }
  

  /**
   * Residual sorter: smallest number of standard deviations first.
   *
   * The sorter keeps a reference to the track, so the track must stay alive
   * for as long as the sorter is used.
   */
  template<class Track_>
  class ResidualSorter {
  public:
    /**
     * Constructor
     * \param  track_    track (stored by reference)
     */
    ResidualSorter(const Track_& track_)
      : track(track_)
    {}

    /**
     * Compare hit time residuals
     *
     * The absolute residuals are compared in cross-multiplied form
     * (|residual(a)| * sigma(b) versus |residual(b)| * sigma(a)),
     * so that no division by sigma is needed.
     *
     * \param  a hit object
     * \param  b hit object
     * \return   true if time residual of first hit is smaller; else false
     */
    template<class T>
    bool operator()(const T& a, const T& b) const
    {
      return (fabs(residual(track,b)) * a.sigma()
              >
              fabs(residual(track,a)) * b.sigma());
    }

  private:
    const Track_& track;
  };


  /**
   * Residual selection using sigma of hit.
   *
   * The selector keeps a reference to the track, so the track must stay alive
   * for as long as the selector is used; the number of standard deviations is
   * stored by value.
   */
  template<class Track_>
  class ResidualSelector {
  public:
    /**
     * Constructor
     * \param  track_    track (stored by reference)
     * \param  stdev_    number of standard deviations
     */
    ResidualSelector(const Track_& track_, const double& stdev_)
      : track(track_)
      , stdev(stdev_)
    {}

    /**
     * Test hit time residual
     * \param  o hit object
     * \return   true if |residual| < sigma x number of standard deviations (strictly less); else false
     */
    template<class T>
    bool operator()(const T& o) const
    {
      return (o.sigma() * stdev > fabs(residual(track,o)));
    }

  private:
    const Track_& track;
    const double  stdev;
  };
}

#endif

