
.. _program_listing_file_src_tfc_utils_BF.h:

Program Listing for File BF.h
=============================

|exhale_lsh| :ref:`Return to documentation for file <file_src_tfc_utils_BF.h>` (``src/tfc/utils/BF.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #define _USE_MATH_DEFINES // Needed by Windows
   #include <Python.h>
   #include <float.h>
   #include <iostream>
   #include <math.h>
   #include <vector>
   #ifdef HAS_CUDA
   #include <cuda.h>
   #include <cuda_runtime.h>
   #include <cuda_runtime_api.h>
   #endif
   
   #ifndef BF_H
   #define BF_H
   
   // BasisFunc
   // **************************************************************************************************************************
   // Abstract base class shared by every basis-function family declared in this header.
   //
   // Concrete families implement the private pure-virtual hooks Hint() and RecurseDeriv();
   // the constructor, destructor, H(), xla(), xlaGpu() and the capsule getters are defined out
   // of line (not visible here).
   //
   // Every instance data member has a default member initializer. This matters because the
   // protected default constructor is the one that runs for the *virtual* BasisFunc base of the
   // n-D classes further down (nCP, nLeP, nFS and the nELM family): C++ makes the most-derived
   // class construct a virtual base, and none of those constructors names BasisFunc in its
   // initializer list. Without these initializers those members would be indeterminate.
   class BasisFunc {
   
     public:
       // Scalar domain-mapping parameters. Presumably z0 is the start of the native basis domain,
       // x0 the start of the user domain and c the scale between them -- TODO confirm against the
       // out-of-line constructor.
       double z0 = 0.;
   
       double x0 = 0.;
   
       double c = 0.;
   
       // Integer array derived from the constructor's nCin/ncDim0 arguments (presumably numC
       // entries; allocation/ownership is handled out of line -- confirm).
       int *nC = nullptr;
   
       int numC = 0;
   
       int m = 0;
   
       // Sentinel -1 until an out-of-line constructor assigns a real identifier
       // (presumably taken from nIdentifier -- confirm).
       int identifier = -1;
   
       PyObject *xlaCapsule = nullptr;
   
   #ifdef HAS_CUDA
       PyObject *xlaGpuCapsule = nullptr;
   #else
       const char *xlaGpuCapsule = "CUDA NOT FOUND, GPU NOT IMPLEMENTED.";
   #endif
   
       // Class-wide counter and registry; presumably used to hand out `identifier` values and to
       // find an instance from XLA callbacks -- confirm against BF.cpp.
       static int nIdentifier;
   
       static std::vector<BasisFunc *> BasisFuncContainer;
   
     public:
       // x0in/xf presumably bound the user domain and z0in/zf the native domain of the basis;
       // by default the native domain is [0, DBL_MAX].
       BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int min, double z0in = 0., double zf = DBL_MAX);
   
       virtual ~BasisFunc();
   
       // Prevent copying
       BasisFunc(const BasisFunc &) = delete;
       BasisFunc &operator=(const BasisFunc &) = delete;
   
       // Prevent moving
       BasisFunc(BasisFunc &&) = delete;
       BasisFunc &operator=(BasisFunc &&) = delete;
   
       // Evaluates the basis (derivative order d) at the n points in x; the output array is returned
       // through F with its dimensions in *nOut and *mOut (presumably; defined out of line).
       virtual void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full);
   
       // XLA CPU entry point: `out` is the output buffer, `in` the array of input buffers.
       virtual void xla(void *out, void **in);
   
   #ifdef HAS_CUDA
       void xlaGpu(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
   #endif
   
     protected:
       // Used when BasisFunc is a virtual base constructed implicitly by a most-derived class;
       // the default member initializers above give every field a defined value in that case.
       BasisFunc() {}
   
       PyObject *GetXlaCapsule();
   
   #ifdef HAS_CUDA
       PyObject *GetXlaCapsuleGpu();
   #endif
   
     private:
       // Family-specific kernel: fills `dark` for derivative order d at nOut points (presumably).
       virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;
   
       // Recursive derivative helper; some families (FS, ELM) override it with warning-only stubs.
       virtual void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) = 0;
   };
   
   // XLA-related declarations
   // **********************************************************************************************************
   // Function-pointer type with the same parameter list as BasisFunc::xla(void *out, void **in);
   // presumably used when registering the CPU custom call.
   using xlaFnType = void (*)(void *, void **);
   
   #ifdef HAS_CUDA
   // Function-pointer type matching the parameter list of BasisFunc::xlaGpu.
   using xlaGpuFnType = void (*)(CUstream, void **, const char *, size_t);
   
   // Free-function GPU entry point (defined out of line).
   void xlaGpuWrapper(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
   #endif
   
   // CP:
   // ********************************************************************************************************************************
   // 1-D basis on the native domain [-1, 1]: the constructor forwards z0 = -1 and zf = 1 to
   // BasisFunc. The name suggests Chebyshev polynomials -- confirm against BF.cpp.
   // BasisFunc is inherited virtually so that nCP, which also derives from nBasisFunc, ends up
   // with a single BasisFunc subobject.
   class CP : public virtual BasisFunc {
     public:
       CP(double x0,
          double xf,
          const int *nCin,
          int ncDim0,
          int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}
   
       virtual ~CP() = default;
   
       // Neither copyable nor movable (mirrors BasisFunc).
       CP(const CP &) = delete;
       CP(CP &&) = delete;
       CP &operator=(const CP &) = delete;
       CP &operator=(CP &&) = delete;
   
     protected:
       // For use as a base of an n-D class (e.g. nCP), where the most-derived class is the one
       // that constructs the virtual BasisFunc base.
       CP() {}
   
       // BasisFunc hooks; nCP also calls them explicitly.
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       void RecurseDeriv(const int d,
                         int dCurr,
                         const double *x,
                         const int nOut,
                         double *&F,
                         const int mOut);
   };
   
   // LeP:
   // ********************************************************************************************************************************
   // 1-D basis on the native domain [-1, 1]: the constructor forwards z0 = -1 and zf = 1 to
   // BasisFunc. The name suggests Legendre polynomials -- confirm against BF.cpp.
   // BasisFunc is inherited virtually so that nLeP, which also derives from nBasisFunc, ends up
   // with a single BasisFunc subobject.
   class LeP : public virtual BasisFunc {
     public:
       LeP(double x0,
           double xf,
           const int *nCin,
           int ncDim0,
           int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}
   
       // Does not forward anything to BasisFunc; intended for use as a base of nLeP.
       // NOTE(review): unlike CP's default constructor this one is public, so a standalone LeP
       // can be created without any domain information. Kept public for compatibility.
       LeP() {}
   
       ~LeP() = default;
   
     protected:
       // BasisFunc hooks; nLeP also calls them explicitly.
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       void RecurseDeriv(const int d,
                         int dCurr,
                         const double *x,
                         const int nOut,
                         double *&F,
                         const int mOut);
   };
   
   // LaP:
   // ********************************************************************************************************************************
   // 1-D basis that relies on BasisFunc's default native domain (z0in = 0., zf = DBL_MAX), since
   // the constructor forwards neither. The name suggests Laguerre polynomials -- confirm against BF.cpp.
   class LaP : public BasisFunc {
     public:
       LaP(double x0,
           double xf,
           const int *nCin,
           int ncDim0,
           int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min) {}
   
       ~LaP() = default;
   
     private:
       // BasisFunc hooks (defined out of line).
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       void RecurseDeriv(const int d,
                         int dCurr,
                         const double *x,
                         const int nOut,
                         double *&F,
                         const int mOut);
   };
   
   // HoPpro:
   // ********************************************************************************************************************************
   // 1-D basis that relies on BasisFunc's default native domain (z0in = 0., zf = DBL_MAX), since the
   // constructor forwards neither. The name suggests probabilists' Hermite polynomials -- confirm
   // against BF.cpp.
   class HoPpro : public BasisFunc {
     public:
       HoPpro(double x0,
              double xf,
              const int *nCin,
              int ncDim0,
              int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min) {}
   
       ~HoPpro() = default;
   
     private:
       // BasisFunc hooks (defined out of line).
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       void RecurseDeriv(const int d,
                         int dCurr,
                         const double *x,
                         const int nOut,
                         double *&F,
                         const int mOut);
   };
   
   // HoPphy:
   // ********************************************************************************************************************************
   // 1-D basis that relies on BasisFunc's default native domain (z0in = 0., zf = DBL_MAX), since the
   // constructor forwards neither. The name suggests physicists' Hermite polynomials -- confirm
   // against BF.cpp.
   class HoPphy : public BasisFunc {
     public:
       HoPphy(double x0,
              double xf,
              const int *nCin,
              int ncDim0,
              int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min) {}
   
       ~HoPphy() = default;
   
     private:
       // BasisFunc hooks (defined out of line).
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       void RecurseDeriv(const int d,
                         int dCurr,
                         const double *x,
                         const int nOut,
                         double *&F,
                         const int mOut);
   };
   
   // FS:
   // ********************************************************************************************************************************
   // 1-D basis on the native domain [-pi, pi]: the constructor forwards z0 = -M_PI and zf = M_PI to
   // BasisFunc. The name suggests a Fourier series -- confirm against BF.cpp.
   // BasisFunc is inherited virtually so that nFS, which also derives from nBasisFunc, ends up with
   // a single BasisFunc subobject. RecurseDeriv() is a warning-only stub for this family.
   class FS : public virtual BasisFunc {
     public:
       FS(double x0,
          double xf,
          const int *nCin,
          int ncDim0,
          int min)
           : BasisFunc(x0, xf, nCin, ncDim0, min, -M_PI, M_PI) {}
   
       // Does not forward anything to BasisFunc; intended for use as a base of nFS.
       FS() {}
   
       ~FS() = default;
   
     protected:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   
       // Must never be reached: writes the same warning to stderr and to stdout and leaves F untouched.
       void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
           const char *const warning = "Warning, this function from FS should never be called. It seems it has been "
                                       "called by accident. Please check that this function was intended to be "
                                       "called.\n";
           fprintf(stderr, "%s", warning);
           printf("%s", warning);
       }
   };
   
   // ELM base class:
   // ********************************************************************************************************************************
   // Base of the 1-D ELM family (ELMSigmoid, ELMReLU, ELMTanh, ELMSin, ELMSwish). Each subclass
   // only supplies Hint(). The name suggests "extreme learning machine" features built from the
   // weights w and biases b -- confirm against BF.cpp.
   class ELM : public BasisFunc {
     public:
       // Weight and bias arrays, presumably allocated by the out-of-line constructor and released
       // by ~ELM(). Initialized to nullptr so they never hold indeterminate pointer values.
       double *w = nullptr;
   
       double *b = nullptr;
   
       ELM(double x0, double xf, const int *nCin, int ncDim0, int min);
   
       virtual ~ELM();
   
       // Accessors: copy the weights/biases out through (arrOut, nOut) or in from (arrIn, nIn).
       void getW(double **arrOut, int *nOut);
   
       void setW(const double *arrIn, int nIn);
   
       void getB(double **arrOut, int *nOut);
   
       void setB(const double *arrIn, int nIn);
   
     protected:
       // Activation-specific kernel supplied by each subclass.
       virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;
   
       // According to its own warning text this override should never be reached: it writes the
       // same warning to stderr and to stdout and leaves F untouched.
       void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
           const char *const warning = "Warning, this function from ELM should never be called. It seems it has been "
                                       "called by accident. Please check that this function was intended to be "
                                       "called.\n";
           fprintf(stderr, "%s", warning);
           printf("%s", warning);
       }
   };
   
   // ELM sigmoid:
   // ********************************************************************************************************************************
   // 1-D ELM variant; the name suggests a sigmoid activation, implemented by Hint() out of line.
   class ELMSigmoid : public ELM {
     public:
       ELMSigmoid(double x0,
                  double xf,
                  const int *nCin,
                  int ncDim0,
                  int min)
           : ELM(x0, xf, nCin, ncDim0, min) {}
   
       ~ELMSigmoid() = default;
   
     protected:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   };
   
   // ELM ReLU:
   // ********************************************************************************************************************************
   // 1-D ELM variant; the name suggests a ReLU activation, implemented by Hint() out of line.
   class ELMReLU : public ELM {
     public:
       ELMReLU(double x0,
               double xf,
               const int *nCin,
               int ncDim0,
               int min)
           : ELM(x0, xf, nCin, ncDim0, min) {}
   
       ~ELMReLU() = default;
   
     protected:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   };
   
   // ELM Tanh:
   // ********************************************************************************************************************************
   // 1-D ELM variant; the name suggests a tanh activation, implemented by Hint() out of line.
   class ELMTanh : public ELM {
     public:
       ELMTanh(double x0,
               double xf,
               const int *nCin,
               int ncDim0,
               int min)
           : ELM(x0, xf, nCin, ncDim0, min) {}
   
       ~ELMTanh() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   };
   
   // ELM Sin:
   // ********************************************************************************************************************************
   // 1-D ELM variant; the name suggests a sine activation, implemented by Hint() out of line.
   class ELMSin : public ELM {
     public:
       ELMSin(double x0,
              double xf,
              const int *nCin,
              int ncDim0,
              int min)
           : ELM(x0, xf, nCin, ncDim0, min) {}
   
       ~ELMSin() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   };
   
   // ELM Swish:
   // ********************************************************************************************************************************
   // 1-D ELM variant; the name suggests a swish activation, implemented by Hint() out of line.
   class ELMSwish : public ELM {
     public:
       ELMSwish(double x0,
                double xf,
                const int *nCin,
                int ncDim0,
                int min)
           : ELM(x0, xf, nCin, ncDim0, min) {}
   
       ~ELMSwish() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark);
   };
   
   // n-D Basis function base class:
   // ***************************************************************************************************
   
   // Common base of the multi-dimensional basis functions (nCP, nLeP, nFS and the nELM family).
   // It derives virtually from BasisFunc so that classes such as nCP, which also derive from a 1-D
   // family (CP), share a single BasisFunc subobject.
   //
   // The data members below have default member initializers so that objects built through the
   // protected default constructor never hold indeterminate values.
   class nBasisFunc : virtual public BasisFunc {
   
     public:
       // NOTE(review): z0, c and x0 hide the identically named BasisFunc members, and c/x0 even have
       // different types there (double in BasisFunc, double * here). Unqualified uses inside
       // nBasisFunc and its subclasses refer to these. Renaming would require matching changes in
       // the out-of-line implementation, so the names are left unchanged.
       double z0 = 0.;
   
       double zf = 0.;
   
       // Presumably arrays with one entry per dimension -- confirm against the constructor.
       double *c = nullptr;
   
       double *x0 = nullptr;
   
       int dim = 0;
   
       int numBasisFunc = 0;
   
       int numBasisFuncFull = 0;
   
     public:
       // NOTE(review): the default zfin (0.) equals the default z0in, i.e. an empty native range.
       // nCP, nLeP and nFS all pass explicit values, so the defaults only matter for other callers.
       nBasisFunc(const double *x0in,
                  int x0Dim0,
                  const double *xf,
                  int xfDim0,
                  const int *nCin,
                  int ncDim0,
                  int ncDim1,
                  int min,
                  double z0in = 0.,
                  double zfin = 0.);
   
       virtual ~nBasisFunc();
   
       // n-D evaluation entry point (the 1-D H overload inherited from BasisFunc is re-declared
       // private below). Output array and its dimensions are returned through F, *nOut and *mOut.
       void
       H(const double *x, int in, int xDim1, const int *d, int dDim0, int *nOut, int *mOut, double **F, const bool full);
   
       void xla(void *out, void **in) override;
   
       void getC(double **arrOut, int *nOut);
   
     protected:
       // Members keep the default member initializers above when this constructor is used.
       nBasisFunc() {}
   
     private:
       // Overrides the 1-D BasisFunc::H and makes it private, steering callers to the n-D overload.
       void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) override;
   
       void RecurseBasis(int dimCurr,
                         int *vec,
                         int &count,
                         const bool full,
                         const int in,
                         const int numBasis,
                         const double *T,
                         double *out);
   
       void NumBasisFunc(int dimCurr, int *vec, int &count, const bool full);
   
       // Overridden by nELM; presumably evaluates the n-D basis for the derivative orders in d.
       virtual void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full);
   
       // 1-D hooks, re-declared pure so each n-D subclass must pick an implementation.
       virtual void Hint(const int d, const double *x, const int nOut, double *dark) override = 0;
   
       virtual void
       RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override = 0;
   };
   
   // n-D CP class:
   // ******************************************************************************************************************
   // Multi-dimensional CP basis on the native domain [-1, 1] (forwarded to nBasisFunc).
   // Both parents override Hint()/RecurseDeriv() of the shared virtual BasisFunc base (nBasisFunc
   // re-declares them pure), so nCP must provide its own overriders; they delegate to CP.
   // As the most-derived class, nCP default-constructs the virtual BasisFunc base.
   class nCP : public nBasisFunc, public CP {
     public:
       nCP(const double *x0in,
           int x0Dim0,
           const double *xf,
           int xfDim0,
           const int *nCin,
           int ncDim0,
           int ncDim1,
           int min)
           : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}
   
       ~nCP() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark) {
           CP::Hint(d, x, nOut, dark);
       }
   
       void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
           CP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
       }
   };
   
   // n-D LeP class:
   // ******************************************************************************************************************
   // Multi-dimensional LeP basis on the native domain [-1, 1] (forwarded to nBasisFunc).
   // Both parents override Hint()/RecurseDeriv() of the shared virtual BasisFunc base (nBasisFunc
   // re-declares them pure), so nLeP must provide its own overriders; they delegate to LeP.
   // As the most-derived class, nLeP default-constructs the virtual BasisFunc base.
   class nLeP : public nBasisFunc, public LeP {
     public:
       nLeP(const double *x0in,
            int x0Dim0,
            const double *xf,
            int xfDim0,
            const int *nCin,
            int ncDim0,
            int ncDim1,
            int min)
           : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}
   
       ~nLeP() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark) {
           LeP::Hint(d, x, nOut, dark);
       }
   
       void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
           LeP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
       }
   };
   
   // n-D FS class:
   // ******************************************************************************************************************
   // Multi-dimensional FS basis on the native domain [-pi, pi] (forwarded to nBasisFunc).
   // Both parents override Hint()/RecurseDeriv() of the shared virtual BasisFunc base (nBasisFunc
   // re-declares them pure), so nFS must provide its own overriders; they delegate to FS.
   // As the most-derived class, nFS default-constructs the virtual BasisFunc base.
   class nFS : public nBasisFunc, public FS {
     public:
       nFS(const double *x0in,
           int x0Dim0,
           const double *xf,
           int xfDim0,
           const int *nCin,
           int ncDim0,
           int ncDim1,
           int min)
           : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -M_PI, M_PI) {}
   
       ~nFS() = default;
   
     private:
       void Hint(const int d, const double *x, const int nOut, double *dark) {
           FS::Hint(d, x, nOut, dark);
       }
   
       void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
           FS::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
       }
   };
   
   // n-D ELM base class:
   // *******************************************************************************************************************************************************
   // Base of the multi-dimensional ELM family (nELMSigmoid, nELMTanh, nELMSin, nELMSwish, nELMReLU).
   // Unlike nCP/nLeP/nFS it does not mix in a 1-D family. It overrides nHint() and requires each
   // subclass to implement nElmHint(), presumably the activation-specific kernel. The 1-D hooks
   // Hint()/RecurseDeriv() are warning-only stubs.
   class nELM : public nBasisFunc {
   
     public:
       // NOTE(review): z0/zf hide nBasisFunc::z0/zf (and z0 also hides BasisFunc::z0). Left as-is
       // because the out-of-line implementation refers to them by these names.
       double z0 = 0.;
   
       double zf = 0.;
   
       // Weights and biases; presumably allocated by the constructor and released by ~nELM().
       // Initialized to nullptr so they never hold indeterminate pointer values.
       double *w = nullptr;
   
       double *b = nullptr;
   
       nELM(const double *x0in,
            int x0Dim0,
            const double *xf,
            int xfDim0,
            const int *nCin,
            int ncDim0,
            int min,
            double z0in = 0.,
            double zfin = 1.);
   
       virtual ~nELM();
   
       // Weight accessors; note the parameter order differs between setW and getW.
       void setW(const double *arrIn, int dimIn, int nIn);
   
       void getW(int *dimOut, int *nOut, double **arrOut);
   
       void getB(double **arrOut, int *nOut);
   
       void setB(const double *arrIn, int nIn);
   
     private:
       void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full) override;
   
       virtual void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) = 0;
   
       // Must never be reached: writes the same warning to stderr and to stdout.
       void Hint(const int, const double *, const int, double *) override {
           const char *const warning = "Warning, this function from nELM should never be called. It seems it has been "
                                       "called by accident. Please check that this function was intended to be "
                                       "called.\n";
           fprintf(stderr, "%s", warning);
           printf("%s", warning);
       }
   
       // Must never be reached: writes the same warning to stderr and to stdout; F is left untouched.
       void RecurseDeriv(const int, int, const double *, const int, double *&, const int) override {
           const char *const warning = "Warning, this function from nELM should never be called. It seems it has been "
                                       "called by accident. Please check that this function was intended to be "
                                       "called.\n";
           fprintf(stderr, "%s", warning);
           printf("%s", warning);
       }
   };
   
   // n-D ELM sigmoid class:
   // *******************************************************************************************************************************************************
   // n-D ELM variant; the name suggests a sigmoid activation, implemented by nElmHint() out of line.
   // The constructor does not forward z0in/zfin, so nELM's defaults (0., 1.) apply.
   class nELMSigmoid : public nELM {
     public:
       nELMSigmoid(const double *x0in,
                   int x0Dim0,
                   const double *xf,
                   int xfDim0,
                   const int *nCin,
                   int ncDim0,
                   int min)
           : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
   
       ~nELMSigmoid() = default;
   
     private:
       void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
   };
   
   // n-D ELM Tanh class:
   // *******************************************************************************************************************************************************
   // n-D ELM variant; the name suggests a tanh activation, implemented by nElmHint() out of line.
   // The constructor does not forward z0in/zfin, so nELM's defaults (0., 1.) apply.
   class nELMTanh : public nELM {
     public:
       nELMTanh(const double *x0in,
                int x0Dim0,
                const double *xf,
                int xfDim0,
                const int *nCin,
                int ncDim0,
                int min)
           : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
   
       ~nELMTanh() = default;
   
     private:
       void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
   };
   
   // n-D ELM Sin class:
   // *******************************************************************************************************************************************************
   // n-D ELM variant; the name suggests a sine activation, implemented by nElmHint() out of line.
   // The constructor does not forward z0in/zfin, so nELM's defaults (0., 1.) apply.
   class nELMSin : public nELM {
     public:
       nELMSin(const double *x0in,
               int x0Dim0,
               const double *xf,
               int xfDim0,
               const int *nCin,
               int ncDim0,
               int min)
           : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
   
       ~nELMSin() = default;
   
     private:
       void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
   };
   
   // n-D ELM Swish class:
   // *******************************************************************************************************************************************************
   // n-D ELM variant; the name suggests a swish activation, implemented by nElmHint() out of line.
   // The constructor does not forward z0in/zfin, so nELM's defaults (0., 1.) apply.
   class nELMSwish : public nELM {
     public:
       nELMSwish(const double *x0in,
                 int x0Dim0,
                 const double *xf,
                 int xfDim0,
                 const int *nCin,
                 int ncDim0,
                 int min)
           : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
   
       ~nELMSwish() = default;
   
     private:
       void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
   };
   
   // n-D ELM ReLU class:
   // *******************************************************************************************************************************************************
   // n-D ELM variant; the name suggests a ReLU activation, implemented by nElmHint() out of line.
   // The constructor does not forward z0in/zfin, so nELM's defaults (0., 1.) apply.
   class nELMReLU : public nELM {
     public:
       nELMReLU(const double *x0in,
                int x0Dim0,
                const double *xf,
                int xfDim0,
                const int *nCin,
                int ncDim0,
                int min)
           : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
   
       ~nELMReLU() = default;
   
     private:
       void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
   };
   
   #endif
