#include "Eigen/Dense"
#include <iostream>
#include <cassert>
#include <limits>
#include <cmath>

#include "dco.hpp"

template<typename TA, typename TP, int N>
TA f(const Eigen::Matrix<TA,N,1> &p, const Eigen::Matrix<TP,N,1> &x) { 
  TA px=0;
  for (auto i=0;i<p.size();i++) px+=p(i)*x(i);
  return pow(px,2); 
}

template<typename TA, typename TP,int M, int N>
TA E(const Eigen::Matrix<TA,N,1> &p, const Eigen::Matrix<TP,M,N> &x, const Eigen::Matrix<TP,M,1> &y) {
  int m=y.size();
  TA o=0;
  for (auto i=0;i<m;i++) {
    Eigen::Matrix<TP,N,1> xr=x.row(i);
    o+=pow(f(p,xr)-y(i),2);
  }
  return o;
}

template<typename TA, typename TP, int M, int N>
Eigen::Matrix<TA,N,1> dEdp(const Eigen::Matrix<TA,N,1> &p_v, const Eigen::Matrix<TP,M,N> &x, const Eigen::Matrix<TP,M,1> &y) {
  using DCO_M=typename dco::ga1s<TA>;
  using DCO_T=typename DCO_M::type;
  using DCO_TT=typename DCO_M::tape_t;
  int n=p_v.size();
  Eigen::Matrix<TA,N,1> dodp(n); 
  Eigen::Matrix<DCO_T,N,1> p(n); DCO_T o=0;
  DCO_M::global_tape=DCO_TT::create();
  for (auto i=0;i<n;i++) {
    p(i)=p_v(i);
    DCO_M::global_tape->register_variable(p(i));
  }
  o=E(p,x,y);
  dco::derivative(o)=1;
  DCO_M::global_tape->interpret_adjoint();
  for (auto i=0;i<n;i++) dodp(i)=dco::derivative(p(i));
  DCO_TT::remove(DCO_M::global_tape);
  return dodp;
}

template<typename T, int M, int N>
Eigen::Matrix<T,N,N> ddEdpp(const Eigen::Matrix<T,N,1> &p_v, const Eigen::Matrix<T,M,N> &x, const Eigen::Matrix<T,M,1> &y) {
  using DCO_T=typename dco::gt1s<T>::type;
  int n=p_v.size();
  Eigen::Matrix<T,N,N> ddodpp(n,n);
  Eigen::Matrix<DCO_T,N,1> p(n); 
  for (auto i=0;i<n;i++) p(i)=p_v(i);
  for (auto i=0;i<n;i++) {
    dco::derivative(p(i))=1;
    Eigen::Matrix<DCO_T,N,1> dodp=dEdp(p,x,y);
    for (auto j=0;j<n;j++) ddodpp(j,i)=dco::derivative(dodp(j));
    dco::derivative(p(i))=0;
  }
  return ddodpp;
}

template<typename T, int M, int N>
void Newton(Eigen::Matrix<T,N,1>& p, const Eigen::Matrix<T,M,N> &x, const Eigen::Matrix<T,M,1> &y, const T& eps) {
  Eigen::Matrix<T,N,1> dodp=dEdp(p,x,y);
  do {
    Eigen::Matrix<T,N,N> ddodpp=ddEdpp(p,x,y);
    p+=ddodpp.llt().solve(-dodp);
    dodp=dEdp(p,x,y);
  } while (dodp.norm()>eps);
}

int main(int argc, char* argv[]) {
  assert(argc==3); int m=std::stoi(argv[1]), n=std::stoi(argv[2]);
  using T=double;
  using MT=Eigen::Matrix<T,Eigen::Dynamic,Eigen::Dynamic>;
  using VT=Eigen::Matrix<T,Eigen::Dynamic,1>;
  MT x=MT::Random(m,n); x=x.cwiseProduct(x);
  VT y=VT::Random(m), p=VT::Random(n);
  Newton(p,x,y,1e-5);
  std::cout << "p^T=" << p.transpose() << std::endl;
  std::cout << "E'^T=" << dEdp(p,x,y).transpose() << std::endl;
  std::cout << "E''=" << std::endl << ddEdpp(p,x,y) << std::endl;
  Eigen::LLT<Eigen::Matrix<T,Eigen::Dynamic,Eigen::Dynamic>> spd_test(ddEdpp(p,x,y));
  if (spd_test.info()!=Eigen::NumericalIssue)
    std::cout << "Hessian is spd." << std::endl;
  else
    std::cout << "Hessian appears to be not spd!" << std::endl;
  return 0;
}
