Adjoint Code Design Patterns  v0.1
U.Naumann
ACEnsemblePathwise.hpp
1 #ifndef ACENSEMBLE_PATHWISE_INCLUDED
2 #define ACENSEMBLE_PATHWISE_INCLUDED
3 
4 #include<iostream>
5 using namespace std;
6 
7 #include "ACEnsemble.hpp"
8 
9 namespace ACDesignPatterns {
10 
15 template<typename T>
16 class ACEnsemblePathwise : public ACEnsemble<T> {
17  protected:
18  using ACEnsemble<T>::path;
19  public:
24  using ACModule<T>::x;
25  using ACModule<T>::ax;
26  using ACModule<T>::y;
27  using ACModule<T>::ay;
28  ACEnsemblePathwise(size_t npaths) : ACEnsemble<T>(npaths) {};
29  void evaluate_primal() {
30  T xs=x[0],ys=0;
31  for (size_t t=0;t<npaths;t++) {
32  x[0]=xs;
33  path->reset_inoutput();
34  path->register_inoutput(x,ax,y,ay);
35  path->evaluate_primal();
36  ys+=y[0];
37  }
38  ys/=npaths;
39  y[0]=ys;
40  }
41  void evaluate_split_primal() {
42  push_arguments();
43  evaluate_primal();
44  }
45  void evaluate_split_adjoint() {
46  T axs=0,ays=0;
47  ays=ay[0]; ay[0]=0;
48  ays/=npaths;
49  for (size_t t=0;t<npaths;t++) {
50  read_arguments();
51  ay[0]=ays;
52  path->reset_inoutput();
53  path->register_inoutput(x,ax,y,ay);
54  path->evaluate_split_primal();
55  path->evaluate_split_adjoint();
56  axs+=ax[0]; ax[0]=0;
57  }
58  ax[0]=axs; axs=0;
59  pop_arguments();
60  }
61 };
62 
63 }
64 
65 #endif