Adjoint Code Design Patterns  v0.1
U.Naumann
ACModule.hpp
1 #ifndef ACMODULE_INCLUDED
2 #define ACMODULE_INCLUDED
3 
4 #include<stack>
5 #include<vector>
6 using namespace std;
7 
8 namespace ACDesignPatterns {
9 
14 template<typename T>
15 class VarRefs {
16  vector<T*> v;
17  public:
18  size_t register_var(T& z) {
19  v.push_back(&z); return v.size()-1;
20  }
21  void clear() {
22  v.clear();
23  }
24  T& operator[](size_t i) {
25  return *v[i];
26  }
27  size_t size() { return v.size(); }
28 };
29 
33 template<typename T>
34 class ACModule {
35  protected:
36  stack<vector<T> > argument_checkpoint;
37  public:
38  virtual ~ACModule() {};
39  VarRefs<T> x,ax,y,ay;
40  void register_input(T& v, T& a) {
41  x.register_var(v); ax.register_var(a);
42  }
43  void register_input(VarRefs<T>& v, VarRefs<T>& a) {
44  x=v; ax=a;
45  }
46  void register_output(T& v, T& a) {
47  y.register_var(v); ay.register_var(a);
48  }
49  void register_output(VarRefs<T>& v, VarRefs<T>& a) {
50  y=v; ay=a;
51  }
52  void register_inoutput(T& xv, T& xa, T& yv, T& ya) {
53  x.register_var(xv); ax.register_var(xa);
54  y.register_var(yv); ay.register_var(ya);
55  }
56  void register_inoutput(VarRefs<T>& xv, VarRefs<T>& xa, VarRefs<T>& yv, VarRefs<T>& ya) {
57  x=xv; ax=xa; y=yv; ay=ya;
58  }
59  void reset_inoutput() { x.clear(); ax.clear(); y.clear(); ay.clear(); }
60  virtual void push_arguments() {
61  vector<T> a(x.size());
62  for (size_t i=0;i<x.size();i++) a[i]=x[i];
63  argument_checkpoint.push(a);
64  }
65  virtual void read_arguments() {
66  vector<T> a(x.size()); a=argument_checkpoint.top();
67  for (size_t i=0;i<x.size();i++) x[i]=a[i];
68  }
69  virtual void pop_arguments() {
70  argument_checkpoint.pop();
71  }
72  virtual void evaluate_primal()=0;
73  virtual void evaluate_joint_primal() {
74  push_arguments();
75  evaluate_primal();
76  }
77  virtual void evaluate_split_primal()=0;
78  virtual void evaluate_split_adjoint()=0;
79  virtual void evaluate_joint_adjoint() {
80  read_arguments(); pop_arguments();
81  evaluate_split_primal();
82  evaluate_split_adjoint();
83  }
84 };
85 
86 }
87 
88 #endif