Program Listing for File BF.h
↰ Return to documentation for file (src/tfc/utils/BF.h)
#define _USE_MATH_DEFINES // Needed by Windows
#include <Python.h>
#include <float.h>
#include <iostream>
#include <math.h>
#include <vector>
#ifdef HAS_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#endif
#ifndef BF_H
#define BF_H
// BasisFunc
// **************************************************************************************************************************
/// Abstract base class shared by every 1-D basis-function family in this header.
///
/// Derived classes implement Hint() and RecurseDeriv(); H() and xla() are the public
/// evaluation entry points (xla() has the same `void(void *, void **)` shape as the
/// xlaFnType alias declared after this class).
///
/// Every data member carries a default member initializer. The protected default
/// constructor is the one selected whenever a most-derived class does not name BasisFunc
/// in its initializer list -- e.g. classes that reach BasisFunc through virtual
/// inheritance, where the most-derived class constructs the virtual base. Its body is
/// empty, so without these initializers every field (including the raw pointers nC and
/// xlaCapsule) stayed indeterminate until a derived constructor happened to assign it.
class BasisFunc {
public:
    // Parameters of the mapping between the problem domain (x) and the basis domain (z).
    // NOTE(review): presumably z = z0 + c * (x - x0); confirm against BF.cpp.
    double z0 = 0.;
    double x0 = 0.;
    double c = 0.;
    // Array of numC ints derived from the nCin/ncDim0 constructor arguments
    // (allocation/release is in BF.cpp -- confirm ownership there).
    int *nC = nullptr;
    int numC = 0;
    // Basis size passed in as `min` -- TODO confirm exact meaning in BF.cpp.
    int m = 0;
    // Per-instance id; stays at the -1 sentinel unless a constructor assigns one.
    int identifier = -1;
    // Python capsule for the XLA CPU entry point (produced by GetXlaCapsule()).
    PyObject *xlaCapsule = nullptr;
#ifdef HAS_CUDA
    // Python capsule for the XLA GPU entry point (produced by GetXlaCapsuleGpu()).
    PyObject *xlaGpuCapsule = nullptr;
#else
    // Placeholder exposed when the extension is compiled without CUDA support.
    const char *xlaGpuCapsule = "CUDA NOT FOUND, GPU NOT IMPLEMENTED.";
#endif
    // Class-wide counter and instance list (defined in BF.cpp).
    // NOTE(review): presumably BasisFuncContainer is indexed by `identifier`; verify.
    static int nIdentifier;
    static std::vector<BasisFunc *> BasisFuncContainer;

public:
    /// @param x0in, xf  bounds of the problem domain.
    /// @param nCin      array of ncDim0 ints (presumably copied into nC/numC).
    /// @param min       basis size (presumably stored in m).
    /// @param z0in, zf  bounds of the basis domain; default z0in = 0, zf = DBL_MAX.
    BasisFunc(double x0in, double xf, const int *nCin, int ncDim0, int min, double z0in = 0., double zf = DBL_MAX);
    virtual ~BasisFunc();

    // Instances are neither copyable nor movable.
    BasisFunc(const BasisFunc &) = delete;
    BasisFunc &operator=(const BasisFunc &) = delete;
    BasisFunc(BasisFunc &&) = delete;
    BasisFunc &operator=(BasisFunc &&) = delete;

    /// Evaluates the basis at the n points in x; results come back through F with
    /// dimensions reported in *nOut / *mOut. NOTE(review): d is presumably the derivative
    /// order and `full` presumably keeps the terms listed in nC -- confirm in BF.cpp.
    virtual void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full);
    /// XLA CPU custom-call entry point: writes to `out`, reads operands from `in`.
    virtual void xla(void *out, void **in);
#ifdef HAS_CUDA
    /// XLA GPU custom-call entry point.
    void xlaGpu(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
#endif

protected:
    // Used by derived classes that initialise the members themselves (see class comment).
    BasisFunc() {}
    PyObject *GetXlaCapsule();
#ifdef HAS_CUDA
    PyObject *GetXlaCapsuleGpu();
#endif

private:
    // Family-specific kernels supplied by each derived class.
    virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;
    virtual void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) = 0;
};
// XLA related declarations:
// **********************************************************************************************************
/// Function-pointer type of an XLA CPU custom call: (output buffer, operand buffers).
using xlaFnType = void (*)(void *, void **);
#ifdef HAS_CUDA
/// Function-pointer type of an XLA GPU custom call: (stream, buffers, opaque blob, blob length).
using xlaGpuFnType = void (*)(CUstream, void **, const char *, size_t);
void xlaGpuWrapper(CUstream stream, void **buffers, const char *opaque, size_t opaque_len);
#endif
// CP:
// ********************************************************************************************************************************
/// CP basis family (presumably Chebyshev polynomials; kernels live in BF.cpp).
/// Inherits BasisFunc virtually so it can be combined with nBasisFunc in nCP.
class CP : virtual public BasisFunc {
public:
    /// Basis-domain bounds are fixed: z0 = -1, zf = 1 are forwarded to BasisFunc.
    CP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}
    virtual ~CP() = default;

    // Non-copyable and non-movable, matching BasisFunc.
    CP(const CP &) = delete;
    CP(CP &&) = delete;
    CP &operator=(const CP &) = delete;
    CP &operator=(CP &&) = delete;

protected:
    // For derived classes (e.g. nCP) whose most-derived constructor builds BasisFunc.
    CP() {}
    void Hint(const int d, const double *x, const int nOut, double *dark);
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};
// LeP:
// ********************************************************************************************************************************
/// LeP basis family (presumably Legendre polynomials; kernels live in BF.cpp).
/// Inherits BasisFunc virtually so it can be combined with nBasisFunc in nLeP.
class LeP : virtual public BasisFunc {
public:
    /// Basis-domain bounds are fixed: z0 = -1, zf = 1 are forwarded to BasisFunc.
    LeP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -1., 1.) {}
    // NOTE(review): public here, whereas CP's equivalent is protected; calling it directly
    // yields an object whose BasisFunc part comes from BasisFunc's default constructor.
    LeP() {}
    ~LeP() = default;

protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};
// LaP:
// ********************************************************************************************************************************
/// LaP basis family (presumably Laguerre polynomials; kernels live in BF.cpp).
class LaP : public BasisFunc {
public:
    /// Basis-domain bounds are left at BasisFunc's defaults (z0 = 0, zf = DBL_MAX).
    LaP(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {}
    ~LaP() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};
// HoPpro:
// ********************************************************************************************************************************
/// HoPpro basis family (presumably probabilists' Hermite polynomials; kernels in BF.cpp).
class HoPpro : public BasisFunc {
public:
    /// Basis-domain bounds are left at BasisFunc's defaults (z0 = 0, zf = DBL_MAX).
    HoPpro(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {}
    ~HoPpro() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};
// HoPphy:
// ********************************************************************************************************************************
/// HoPphy basis family (presumably physicists' Hermite polynomials; kernels in BF.cpp).
class HoPphy : public BasisFunc {
public:
    /// Basis-domain bounds are left at BasisFunc's defaults (z0 = 0, zf = DBL_MAX).
    HoPphy(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min) {}
    ~HoPphy() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut);
};
// FS:
// ********************************************************************************************************************************
/// FS basis family (presumably a Fourier series; kernels live in BF.cpp).
/// Inherits BasisFunc virtually so it can be combined with nBasisFunc in nFS.
class FS : virtual public BasisFunc {
public:
    /// Basis-domain bounds are fixed: z0 = -M_PI, zf = M_PI are forwarded to BasisFunc.
    FS(double x0, double xf, const int *nCin, int ncDim0, int min)
        : BasisFunc(x0, xf, nCin, ncDim0, min, -M_PI, M_PI) {}
    FS() {}
    ~FS() = default;

protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);

    /// Stub: per its own warning this is never meant to run. It only reports the
    /// unexpected call, once on stderr and once on stdout, and leaves F untouched.
    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
        const char *const warning =
            "Warning, this function from FS should never be called. It seems it has been called by accident. "
            "Please check that this function was intended to be called.\n";
        fprintf(stderr, "%s", warning);
        printf("%s", warning);
    }
};
// ELM base class:
// ********************************************************************************************************************************
/// Common base for the 1-D ELM families (sigmoid, ReLU, tanh, sin, swish).
/// Subclasses supply only Hint(); RecurseDeriv() is a stub here.
class ELM : public BasisFunc {
public:
    // Weight and bias arrays; accessed through the get/set pairs below
    // (allocation and length bookkeeping live in BF.cpp).
    double *w;
    double *b;

    ELM(double x0, double xf, const int *nCin, int ncDim0, int min);
    virtual ~ELM();

    void getW(double **arrOut, int *nOut);
    void setW(const double *arrIn, int nIn);
    void getB(double **arrOut, int *nOut);
    void setB(const double *arrIn, int nIn);

protected:
    virtual void Hint(const int d, const double *x, const int nOut, double *dark) = 0;

    /// Stub: per its own warning this is never meant to run. It only reports the
    /// unexpected call, once on stderr and once on stdout, and leaves F untouched.
    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) {
        const char *const warning =
            "Warning, this function from ELM should never be called. It seems it has been called by accident. "
            "Please check that this function was intended to be called.\n";
        fprintf(stderr, "%s", warning);
        printf("%s", warning);
    }
};
// ELM sigmoid:
// ********************************************************************************************************************************
/// ELM family with a sigmoid activation; only Hint() is family-specific.
class ELMSigmoid : public ELM {
public:
    ELMSigmoid(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {}
    ~ELMSigmoid() = default;

protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};
// ELM ReLU:
// ********************************************************************************************************************************
/// ELM family with a ReLU activation; only Hint() is family-specific.
class ELMReLU : public ELM {
public:
    ELMReLU(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {}
    ~ELMReLU() = default;

protected:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};
// ELM Tanh:
// ********************************************************************************************************************************
/// ELM family with a tanh activation; only Hint() is family-specific.
class ELMTanh : public ELM {
public:
    ELMTanh(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {}
    ~ELMTanh() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};
// ELM Sin:
// ********************************************************************************************************************************
/// ELM family with a sine activation; only Hint() is family-specific.
class ELMSin : public ELM {
public:
    ELMSin(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {}
    ~ELMSin() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};
// ELM Swish:
// ********************************************************************************************************************************
/// ELM family with a swish activation; only Hint() is family-specific.
class ELMSwish : public ELM {
public:
    ELMSwish(double x0, double xf, const int *nCin, int ncDim0, int min)
        : ELM(x0, xf, nCin, ncDim0, min) {}
    ~ELMSwish() = default;

private:
    void Hint(const int d, const double *x, const int nOut, double *dark);
};
// n-D Basis function base class:
// ***************************************************************************************************
/// Base class for the n-dimensional basis functions. Concrete classes pair it with a 1-D
/// family (e.g. nCP = nBasisFunc + CP); both sides inherit BasisFunc virtually so the
/// combination holds a single BasisFunc subobject.
///
/// Every data member has a default member initializer: the protected default constructor
/// has an empty body, so objects built through it previously held indeterminate values,
/// including the raw pointers c and x0, until a derived constructor assigned them.
class nBasisFunc : virtual public BasisFunc {
public:
    // NOTE(review): z0, c and x0 shadow the same-named BasisFunc members (c and x0 are
    // pointers here but scalars in BasisFunc); use BasisFunc:: to reach the base ones.
    double z0 = 0.;
    double zf = 0.;
    // Presumably per-dimension arrays of length `dim` (allocated in BF.cpp -- confirm).
    double *c = nullptr;
    double *x0 = nullptr;
    int dim = 0;
    // Number of basis functions; the "Full" variant presumably ignores nC -- confirm.
    int numBasisFunc = 0;
    int numBasisFuncFull = 0;

public:
    /// @param x0in, x0Dim0   lower problem-domain bounds and their count.
    /// @param xf, xfDim0     upper problem-domain bounds and their count.
    /// @param nCin           ncDim0 x ncDim1 array of ints (presumably row-major).
    /// @param z0in, zfin     basis-domain bounds.
    /// NOTE(review): the default zfin = 0 equals the default z0in = 0 (an empty basis
    /// domain); the nCP/nLeP/nFS constructors in this header all pass explicit bounds.
    nBasisFunc(const double *x0in,
               int x0Dim0,
               const double *xf,
               int xfDim0,
               const int *nCin,
               int ncDim0,
               int ncDim1,
               int min,
               double z0in = 0.,
               double zfin = 0.);
    virtual ~nBasisFunc();

    /// n-D evaluation: x holds the points (presumably `in` rows by xDim1 columns), d holds
    /// dDim0 derivative orders (presumably one per dimension); output via F / *nOut / *mOut.
    void
    H(const double *x, int in, int xDim1, const int *d, int dDim0, int *nOut, int *mOut, double **F, const bool full);
    void xla(void *out, void **in) override;
    void getC(double **arrOut, int *nOut);

protected:
    // Used by derived classes that initialise the members themselves (see class comment).
    nBasisFunc() {}

private:
    // The 1-D entry point is hidden here; callers use the n-D overload above.
    void H(const double *x, int n, const int d, int *nOut, int *mOut, double **F, bool full) override;
    void RecurseBasis(int dimCurr,
                      int *vec,
                      int &count,
                      const bool full,
                      const int in,
                      const int numBasis,
                      const double *T,
                      double *out);
    void NumBasisFunc(int dimCurr, int *vec, int &count, const bool full);
    virtual void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full);
    virtual void Hint(const int d, const double *x, const int nOut, double *dark) override = 0;
    virtual void
    RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) override = 0;
};
// n-D CP class:
// ******************************************************************************************************************
/// n-D variant of CP: the n-D driver comes from nBasisFunc, the 1-D kernels from CP.
class nCP : public nBasisFunc, public CP {
public:
    /// Basis-domain bounds are fixed to z0 = -1, zf = 1.
    nCP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}
    ~nCP() = default;

private:
    // nBasisFunc and CP both override these BasisFunc virtuals, so nCP must provide the
    // single final overrider; it simply delegates to CP's implementation.
    void Hint(const int d, const double *x, const int nOut, double *dark) {
        CP::Hint(d, x, nOut, dark);
    }
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        CP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    }
};
// n-D LeP class:
// ******************************************************************************************************************
/// n-D variant of LeP: the n-D driver comes from nBasisFunc, the 1-D kernels from LeP.
class nLeP : public nBasisFunc, public LeP {
public:
    /// Basis-domain bounds are fixed to z0 = -1, zf = 1.
    nLeP(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -1., 1.) {}
    ~nLeP() = default;

private:
    // nBasisFunc and LeP both override these BasisFunc virtuals, so nLeP must provide the
    // single final overrider; it simply delegates to LeP's implementation.
    void Hint(const int d, const double *x, const int nOut, double *dark) {
        LeP::Hint(d, x, nOut, dark);
    }
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        LeP::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    }
};
// n-D FS class:
// ******************************************************************************************************************
/// n-D variant of FS: the n-D driver comes from nBasisFunc, the 1-D kernels from FS.
class nFS : public nBasisFunc, public FS {
public:
    /// Basis-domain bounds are fixed to z0 = -M_PI, zf = M_PI.
    nFS(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int ncDim1, int min)
        : nBasisFunc(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, ncDim1, min, -M_PI, M_PI) {}
    ~nFS() = default;

private:
    // nBasisFunc and FS both override these BasisFunc virtuals, so nFS must provide the
    // single final overrider; it simply delegates to FS's implementation.
    void Hint(const int d, const double *x, const int nOut, double *dark) {
        FS::Hint(d, x, nOut, dark);
    }
    void RecurseDeriv(const int d, int dCurr, const double *x, const int nOut, double *&F, const int mOut) {
        FS::RecurseDeriv(d, dCurr, x, nOut, F, mOut);
    }
};
// n-D ELM base class:
// *******************************************************************************************************************************************************
/// Common base for the n-D ELM families. Subclasses implement nElmHint(); the 1-D
/// Hint()/RecurseDeriv() hooks inherited from the hierarchy are stubs here.
class nELM : public nBasisFunc {
public:
    // NOTE(review): z0/zf shadow nBasisFunc's members of the same name.
    double z0;
    double zf;
    // Weight and bias arrays, accessed through the get/set pairs below.
    double *w;
    double *b;

    /// Basis-domain bounds default to z0in = 0, zfin = 1.
    nELM(const double *x0in,
         int x0Dim0,
         const double *xf,
         int xfDim0,
         const int *nCin,
         int ncDim0,
         int min,
         double z0in = 0.,
         double zfin = 1.);
    virtual ~nELM();

    void setW(const double *arrIn, int dimIn, int nIn);
    void getW(int *dimOut, int *nOut, double **arrOut);
    void getB(double **arrOut, int *nOut);
    void setB(const double *arrIn, int nIn);

private:
    void nHint(const double *x, int in, const int *d, int dDim0, int numBasis, double *&F, const bool full) override;
    virtual void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) = 0;

    // Reports an unexpected call to one of the stubs below, on stderr and then on stdout.
    static void WarnNeverCalled() {
        const char *const warning =
            "Warning, this function from nELM should never be called. It seems it has been called by accident. "
            "Please check that this function was intended to be called.\n";
        fprintf(stderr, "%s", warning);
        printf("%s", warning);
    }

    // Stubs: per their warning these are never meant to run; outputs are left untouched.
    void Hint(const int, const double *, const int, double *) override {
        WarnNeverCalled();
    }
    void RecurseDeriv(const int, int, const double *, const int, double *&, const int) override {
        WarnNeverCalled();
    }
};
// n-D ELM sigmoid class:
// *******************************************************************************************************************************************************
/// n-D ELM with a sigmoid activation; uses nELM's default bounds (z0in = 0, zfin = 1).
class nELMSigmoid : public nELM {
public:
    nELMSigmoid(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
    ~nELMSigmoid() = default;

private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};
// n-D ELM Tanh class:
// *******************************************************************************************************************************************************
/// n-D ELM with a tanh activation; uses nELM's default bounds (z0in = 0, zfin = 1).
class nELMTanh : public nELM {
public:
    nELMTanh(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
    ~nELMTanh() = default;

private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};
// n-D ELM Sin class:
// *******************************************************************************************************************************************************
/// n-D ELM with a sine activation; uses nELM's default bounds (z0in = 0, zfin = 1).
class nELMSin : public nELM {
public:
    nELMSin(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
    ~nELMSin() = default;

private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};
// n-D ELM Swish class:
// *******************************************************************************************************************************************************
/// n-D ELM with a swish activation; uses nELM's default bounds (z0in = 0, zfin = 1).
class nELMSwish : public nELM {
public:
    nELMSwish(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
    ~nELMSwish() = default;

private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};
// n-D ELM ReLU class:
// *******************************************************************************************************************************************************
/// n-D ELM with a ReLU activation; uses nELM's default bounds (z0in = 0, zfin = 1).
class nELMReLU : public nELM {
public:
    nELMReLU(const double *x0in, int x0Dim0, const double *xf, int xfDim0, const int *nCin, int ncDim0, int min)
        : nELM(x0in, x0Dim0, xf, xfDim0, nCin, ncDim0, min) {}
    ~nELMReLU() = default;

private:
    void nElmHint(const int *d, int dDim0, const double *x, const int in, double *F) override;
};
#endif