Program Listing for File BF.h

Program Listing for File BF.h

Return to documentation for file (src/tfc/utils/BF.h)

#define _USE_MATH_DEFINES // Needed by Windows
#include <Python.h>
#include <float.h>
#include <iostream>
#include <math.h>
#include <vector>
#ifdef HAS_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#endif

#ifndef BF_H
#define BF_H

// BasisFunc
// **************************************************************************************************************************
/// Abstract base class of the 1-D basis-function families declared in this header.
///
/// Subclasses implement the private pure-virtual hooks Hint() and RecurseDeriv(); the public entry
/// points H() and xla() are defined out of line. Objects are neither copyable nor movable.
///
/// Every data member carries a default member initializer. This matters because the protected
/// default constructor below is the one that runs whenever a most-derived class does not name
/// BasisFunc in its mem-initializer list -- e.g. nCP, nLeP, nFS and the nELM* classes, which reach
/// BasisFunc through virtual inheritance. Without the initializers such objects would start with
/// indeterminate pointers (nC, xlaCapsule) and scalars, and any read or cleanup of a member that is
/// not explicitly assigned elsewhere would be undefined behavior.
class BasisFunc {

  public:
    // NOTE(review): the member descriptions below are inferred from the constructor parameter names;
    // the actual assignments live in the out-of-line constructor -- confirm there.

    /// Lower end of the basis domain (presumably from z0in).
    double z0 = 0.;

    /// Start of the problem domain (presumably from x0in).
    double x0 = 0.;

    /// Scale factor (presumably mapping the problem domain onto the basis domain).
    double c = 0.;

    /// nC array (presumably a copy of nCin); nullptr until assigned.
    int *nC = nullptr;

    /// Number of entries in nC (presumably ncDim0).
    int numC = 0;

    /// Presumably the number of basis functions (the constructor's `min` argument).
    int m = 0;

    /// Per-instance identifier (presumably derived from nIdentifier / an index into BasisFuncContainer).
    int identifier = 0;

    /// Python capsule exposed for XLA (presumably created by GetXlaCapsule()); nullptr until assigned.
    PyObject *xlaCapsule = nullptr;

#ifdef HAS_CUDA
    /// GPU counterpart of xlaCapsule (presumably created by GetXlaCapsuleGpu()); nullptr until assigned.
    PyObject *xlaGpuCapsule = nullptr;
#else
    /// Placeholder message used when the library is built without CUDA.
    const char *xlaGpuCapsule = "CUDA NOT FOUND, GPU NOT IMPLEMENTED.";
#endif

    /// Static counter shared by all instances (defined out of line; presumably the next identifier).
    static int nIdentifier;

    /// Static container of BasisFunc pointers (defined and maintained out of line).
    static std::vector<BasisFunc *> BasisFuncContainer;

  public:
    /// Defined out of line. z0in and zf default to 0 and DBL_MAX; CP/LeP pass -1 and 1, FS passes
    /// -M_PI and M_PI.
    BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int min, double z0in = 0., double zf = DBL_MAX);

    virtual ~BasisFunc();

    // Prevent copying
    BasisFunc(const BasisFunc &) = delete;
    BasisFunc &operator=(const BasisFunc &) = delete;

    // Prevent moving
    BasisFunc(BasisFunc &&) = delete;
    BasisFunc &operator=(BasisFunc &&) = delete;

    /// Evaluates the basis at the n points in x (d is presumably the derivative order). Results are
    /// returned through nOut, mOut and *F; allocation/ownership of *F is defined out of line -- confirm.
    virtual void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full);

    /// CPU entry point with the same shape as xlaFnType: (output pointer, array of input pointers).
    virtual void xla(void *out, void **in);

#ifdef HAS_CUDA
    void xlaGpu(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
#endif

  protected:
    /// Leaves every member at its default member initializer. Selected when a most-derived class
    /// initializes the (virtual) BasisFunc base implicitly.
    BasisFunc() {}

    PyObject *GetXlaCapsule();

#ifdef HAS_CUDA
    PyObject *GetXlaCapsuleGpu();
#endif

  private:
    virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;

    virtual void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) = 0;
};

// XLA related declarations:
// **********************************************************************************************************
/// Function-pointer type with the same parameter list as BasisFunc::xla:
/// (output pointer, array of input pointers).
using xlaFnType = void (*)(void *, void **);

#ifdef HAS_CUDA
/// Function-pointer type with the same parameter list as BasisFunc::xlaGpu:
/// (stream, buffer array, opaque bytes, opaque length).
using xlaGpuFnType = void (*)(CUstream, void **, const char *, size_t);

/// Free-function GPU entry point (defined out of line).
void xlaGpuWrapper(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
#endif

// CP:
// ********************************************************************************************************************************
class CP : virtual public BasisFunc {
  public:
    CP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {};

    virtual ~CP() {};

    // Prevent copying
    CP(const CP &) = delete;
    CP &operator=(const CP &) = delete;

    // Prevent moving
    CP(CP &&) = delete;
    CP &operator=(CP &&) = delete;

  protected:
    CP() {};

    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};

// LeP:
// ********************************************************************************************************************************
class LeP : virtual public BasisFunc {
  public:
    LeP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {};

    LeP() {};

    ~LeP() {};

  protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};

// LaP:
// ********************************************************************************************************************************
class LaP : public BasisFunc {
  public:
    LaP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {};
    ~LaP() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};

// HoPpro:
// ********************************************************************************************************************************
class HoPpro : public BasisFunc {
  public:
    HoPpro(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {};
    ~HoPpro() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};

// HoPphy:
// ********************************************************************************************************************************
class HoPphy : public BasisFunc {
  public:
    HoPphy(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {};
    ~HoPphy() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};

// FS:
// ********************************************************************************************************************************
class FS : virtual public BasisFunc {
  public:
    FS(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -M_PI, M_PI) {};

    FS() {};

    ~FS() {};

  protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
        fprintf(stderr,
                "Warning, this function from FS should never be called. It seems it has been called by accident. "
                "Please check that this function was intended to be called.\n");
        printf("Warning, this function from FS should never be called. It seems it has been called by accident. Please "
               "check that this function was intended to be called.\n");
    };
};

// ELM base class:
// ********************************************************************************************************************************
class ELM : public BasisFunc {
  public:
    double *w;

    double *b;

    ELM(double x0, double xf, const int *nCin, int ncDim0, int min);

    virtual ~ELM();

    void getW(double **arrOut, int *nOut);

    void setW(const double *arrIn, int nIn);

    void getB(double **arrOut, int *nOut);

    void setB(const double *arrIn, int nIn);

  protected:
    virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;

    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
        fprintf(stderr,
                "Warning, this function from ELM should never be called. It seems it has been called by accident. "
                "Please check that this function was intended to be called.\n");
        printf("Warning, this function from ELM should never be called. It seems it has been called by accident. "
               "Please check that this function was intended to be called.\n");
    };
};

// ELM sigmoid:
// ********************************************************************************************************************************
class ELMSigmoid : public ELM {

  public:
    ELMSigmoid(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {};

    ~ELMSigmoid() {};

  protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};

// ELM ReLU:
// ********************************************************************************************************************************
class ELMReLU : public ELM {

  public:
    ELMReLU(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {};

    ~ELMReLU() {};

  protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};

// ELM Tanh:
// ********************************************************************************************************************************
class ELMTanh : public ELM {

  public:
    ELMTanh(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {};

    ~ELMTanh() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};

// ELM Sin:
// ********************************************************************************************************************************
class ELMSin : public ELM {

  public:
    ELMSin(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {};

    ~ELMSin() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};

// ELM Swish:
// ********************************************************************************************************************************
class ELMSwish : public ELM {

  public:
    ELMSwish(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {};

    ~ELMSwish() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};

// n-D Basis function base class:
// ***************************************************************************************************

class nBasisFunc : virtual public BasisFunc {

  public:
    double z0;

    double zf;

    double *c;

    double *x0;

    int dim;

    int numBasisFunc;

    int numBasisFuncFull;

  public:
    nBasisFunc(const double *x0in,
               int x0Dim0,
               const double *xf,
               int xfDim0,
               const int *nCin,
               int ncDim0,
               int ncDim1,
               int min,
               double z0in = 0.,
               double zfin = 0.);

    virtual ~nBasisFunc();

    void
    H(const double *x, int in, int xDim1, const int *d, int dDim0, int *nOut, int *mOut, double **F, const bool full);

    void xla(void *out, void **in) override;

    void getC(double **arrOut, int *nOut);

  protected:
    nBasisFunc() {};

  private:
    void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) override;

    void RecurseBasis(int dimCurr,
                      int *vec,
                      int &count,
                      const bool full,
                      const int in,
                      const int numBasis,
                      const double *T,
                      double *out);

    void NumBasisFunc(int dimCurr, int *vec, int &count, const bool full);

    virtual void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full);

    virtual void Hint(const int d, const double *x, const int nOut, double *dark) override = 0;

    virtual void
    RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override = 0;
};

// n-D CP class:
// ******************************************************************************************************************
class nCP : public nBasisFunc, public CP {

  public:
    nCP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {};

    ~nCP() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark) { CP::Hint(d, x, nOut, dark); };

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        CP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    };
};

// n-D LeP class:
// ******************************************************************************************************************
class nLeP : public nBasisFunc, public LeP {

  public:
    nLeP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {};

    ~nLeP() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark) { LeP::Hint(d, x, nOut, dark); };

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        LeP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    };
};

// n-D FS class:
// ******************************************************************************************************************
class nFS : public nBasisFunc, public FS {

  public:
    nFS(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -M_PI, M_PI) {};

    ~nFS() {};

  private:
    void Hint(const int d, const double *x, const int nOut, double *dark) { FS::Hint(d, x, nOut, dark); };

    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        FS::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    };
};

// n-D ELM base class:
// *******************************************************************************************************************************************************
class nELM : public nBasisFunc {

  public:
    double z0;

    double zf;

    double *w;

    double *b;

    nELM(const double *x0in,
         int x0Dim0,
         const double *xf,
         int xfDim0,
         const int *nCin,
         int ncDim0,
         int min,
         double z0in = 0.,
         double zfin = 1.);

    virtual ~nELM();

    void setW(const double *arrIn, int dimIn, int nIn);

    void getW(int *dimOut, int *nOut, double **arrOut);

    void getB(double **arrOut, int *nOut);

    void setB(const double *arrIn, int nIn);

  private:
    void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full) override;

    virtual void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) = 0;

    void Hint(const int, const double *, const int, double *) override {
        fprintf(stderr,
                "Warning, this function from nELM should never be called. It seems it has been called by accident. "
                "Please check that this function was intended to be called.\n");
        printf("Warning, this function from nELM should never be called. It seems it has been called by accident. "
               "Please check that this function was intended to be called.\n");
    };

    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) override {
        fprintf(stderr,
                "Warning, this function from nELM should never be called. It seems it has been called by accident. "
                "Please check that this function was intended to be called.\n");
        printf("Warning, this function from nELM should never be called. It seems it has been called by accident. "
               "Please check that this function was intended to be called.\n");
    };
};

// n-D ELM sigmoid class:
// *******************************************************************************************************************************************************
class nELMSigmoid : public nELM {

  public:
    nELMSigmoid(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {};

    ~nELMSigmoid() {};

  private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};

// n-D ELM Tanh class:
// *******************************************************************************************************************************************************
class nELMTanh : public nELM {

  public:
    nELMTanh(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {};

    ~nELMTanh() {};

  private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};

// n-D ELM Sin class:
// *******************************************************************************************************************************************************
class nELMSin : public nELM {

  public:
    nELMSin(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {};

    ~nELMSin() {};

  private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};

// n-D ELM Swish class:
// *******************************************************************************************************************************************************
class nELMSwish : public nELM {

  public:
    nELMSwish(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {};

    ~nELMSwish() {};

  private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};

// n-D ELM ReLU class:
// *******************************************************************************************************************************************************
class nELMReLU : public nELM {

  public:
    nELMReLU(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {};

    ~nELMReLU() {};

  private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};

#endif