Actual source code: cupmcontext.hpp

  1: #if !defined(PETSCDEVICECONTEXTCUPM_HPP)
  2: #define PETSCDEVICECONTEXTCUPM_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupminterface.hpp>

  7: #if !defined(PETSC_HAVE_CXX_DIALECT_CXX11)
  8: #error PetscDeviceContext backends for CUDA and HIP requires C++11
  9: #endif

 11: namespace Petsc {

 13: // Forward declare
 14: template <CUPMDeviceKind T> class CUPMContext;

 16: template <CUPMDeviceKind T>
 17: class CUPMContext : CUPMInterface<T>
 18: {
 19: public:
 20:   PETSC_INHERIT_CUPM_INTERFACE_TYPEDEFS_USING(cupmInterface_t,T)

 22:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 23:   // header, but since we are using the power of templates it must be declared part of
 24:   // this class to have easy access the same typedefs. Technically one can make a
 25:   // templated struct outside the class but it's more code for the same result.
 26:   struct PetscDeviceContext_IMPLS
 27:   {
 28:     cupmStream_t       stream;
 29:     cupmEvent_t        event;
 30:     cupmBlasHandle_t   blas;
 31:     cupmSolverHandle_t solver;
 32:   };

 34: private:
 35:   static cupmBlasHandle_t   _blashandle;
 36:   static cupmSolverHandle_t _solverhandle;

 38:   PETSC_NODISCARD static PetscErrorCode __finalizeBLASHandle() noexcept
 39:   {

 43:     cupmInterface_t::DestroyHandle(_blashandle);
 44:     return(0);
 45:   }

 47:   PETSC_NODISCARD static PetscErrorCode __finalizeSOLVERHandle() noexcept
 48:   {

 52:     cupmInterface_t::DestroyHandle(_solverhandle);
 53:     return(0);
 54:   }

 56:   PETSC_NODISCARD static PetscErrorCode __setupHandles(PetscDeviceContext_IMPLS *dci) noexcept
 57:   {
 58:     PetscErrorCode  ierr;

 61:     if (!_blashandle) {
 62:       cupmInterface_t::InitializeHandle(_blashandle);
 63:       PetscRegisterFinalize(__finalizeBLASHandle);
 64:     }
 65:     if (!_solverhandle) {
 66:       cupmInterface_t::InitializeHandle(_solverhandle);
 67:       PetscRegisterFinalize(__finalizeSOLVERHandle);
 68:     }
 69:     cupmInterface_t::SetHandleStream(_blashandle,dci->stream);
 70:     cupmInterface_t::SetHandleStream(_solverhandle,dci->stream);
 71:     dci->blas   = _blashandle;
 72:     dci->solver = _solverhandle;
 73:     return(0);
 74:   }

 76: public:
 77:   const struct _DeviceContextOps ops {destroy,changeStreamType,setUp,query,waitForContext,synchronize};

 79:   // default constructor
 80:   constexpr CUPMContext() noexcept = default;

 82:   // All of these functions MUST be static in order to be callable from C, otherwise they
 83:   // get the implicit 'this' pointer tacked on
 84:   PETSC_NODISCARD static PetscErrorCode destroy(PetscDeviceContext) noexcept;
 85:   PETSC_NODISCARD static PetscErrorCode changeStreamType(PetscDeviceContext,PetscStreamType) noexcept;
 86:   PETSC_NODISCARD static PetscErrorCode setUp(PetscDeviceContext) noexcept;
 87:   PETSC_NODISCARD static PetscErrorCode query(PetscDeviceContext,PetscBool*) noexcept;
 88:   PETSC_NODISCARD static PetscErrorCode waitForContext(PetscDeviceContext,PetscDeviceContext) noexcept;
 89:   PETSC_NODISCARD static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
 90: };

 92: #define IMPLS_RCAST_(obj_) static_cast<PetscDeviceContext_IMPLS*>((obj_)->data)

 94: template <CUPMDeviceKind T>
 95: inline PetscErrorCode CUPMContext<T>::destroy(PetscDeviceContext dctx) noexcept
 96: {
 97:   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
 98:   cupmError_t              cerr;
 99:   PetscErrorCode           ierr;

102:   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
103:   if (dci->event)  {cerr = cupmEventDestroy(dci->event);CHKERRCUPM(cerr);}
104:   PetscFree(dctx->data);
105:   return(0);
106: }

108: template <CUPMDeviceKind T>
109: inline PetscErrorCode CUPMContext<T>::changeStreamType(PetscDeviceContext dctx, PetscStreamType stype) noexcept
110: {
111:   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);

114:   if (dci->stream) {
115:     cupmError_t cerr;

117:     cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);
118:     dci->stream = nullptr;
119:   }
120:   // set these to null so they aren't usable until setup is called again
121:   dci->blas   = nullptr;
122:   dci->solver = nullptr;
123:   return(0);
124: }

126: template <CUPMDeviceKind T>
127: inline PetscErrorCode CUPMContext<T>::setUp(PetscDeviceContext dctx) noexcept
128: {
129:   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
130:   PetscErrorCode           ierr;
131:   cupmError_t              cerr;

134:   if (dci->stream) {cerr = cupmStreamDestroy(dci->stream);CHKERRCUPM(cerr);}
135:   switch (dctx->streamType) {
136:   case PETSC_STREAM_GLOBAL_BLOCKING:
137:     // don't create a stream for global blocking
138:     dci->stream = nullptr;
139:     break;
140:   case PETSC_STREAM_DEFAULT_BLOCKING:
141:     cerr = cupmStreamCreate(&dci->stream);CHKERRCUPM(cerr);
142:     break;
143:   case PETSC_STREAM_GLOBAL_NONBLOCKING:
144:     cerr = cupmStreamCreateWithFlags(&dci->stream,cupmStreamNonBlocking);CHKERRCUPM(cerr);
145:     break;
146:   default:
147:     SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_CORRUPT,"Invalid PetscStreamType %d",dctx->streamType);
148:     break;
149:   }
150:   if (!dci->event) {cerr = cupmEventCreate(&dci->event);CHKERRCUPM(cerr);}
151:   __setupHandles(dci);
152:   return(0);
153: }

155: template <CUPMDeviceKind T>
156: inline PetscErrorCode CUPMContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
157: {
158:   cupmError_t cerr;

161:   cerr = cupmStreamQuery(IMPLS_RCAST_(dctx)->stream);
162:   if (cerr == cupmSuccess)
163:     *idle = PETSC_TRUE;
164:   else if (cerr == cupmErrorNotReady) {
165:     *idle = PETSC_FALSE;
166:   } else {
167:     // somethings gone wrong
168:     CHKERRCUPM(cerr);
169:   }
170:   return(0);
171: }

173: template <CUPMDeviceKind T>
174: inline PetscErrorCode CUPMContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
175: {
176:   PetscDeviceContext_IMPLS *dcia = IMPLS_RCAST_(dctxa);
177:   PetscDeviceContext_IMPLS *dcib = IMPLS_RCAST_(dctxb);
178:   cupmError_t               cerr;

181:   cerr = cupmEventRecord(dcib->event,dcib->stream);CHKERRCUPM(cerr);
182:   cerr = cupmStreamWaitEvent(dcia->stream,dcib->event,0);CHKERRCUPM(cerr);
183:   return(0);
184: }

186: template <CUPMDeviceKind T>
187: inline PetscErrorCode CUPMContext<T>::synchronize(PetscDeviceContext dctx) noexcept
188: {
189:   PetscDeviceContext_IMPLS *dci = IMPLS_RCAST_(dctx);
190:   cupmError_t               cerr;

193:   // in case anything was queued on the event
194:   cerr = cupmStreamWaitEvent(dci->stream,dci->event,0);CHKERRCUPM(cerr);
195:   cerr = cupmStreamSynchronize(dci->stream);CHKERRCUPM(cerr);
196:   return(0);
197: }

199: // initialize the static member variables
200: template <CUPMDeviceKind T>
201: typename CUPMContext<T>::cupmBlasHandle_t   CUPMContext<T>::_blashandle   = nullptr;

203: template <CUPMDeviceKind T>
204: typename CUPMContext<T>::cupmSolverHandle_t CUPMContext<T>::_solverhandle = nullptr;

206: // shorten this one up a bit
207: using CUPMContextCuda = CUPMContext<CUPMDeviceKind::CUDA>;
208: using CUPMContextHip  = CUPMContext<CUPMDeviceKind::HIP>;

210: // make sure these doesn't leak out
211: #undef CHKERRCUPM
212: #undef IMPLS_RCAST_

214: } // namespace Petsc

216: // shorthand for what is an EXTREMELY long name
217: #define PetscDeviceContext_(impls_) Petsc::CUPMContext<Petsc::CUPMDeviceKind::impls_>::PetscDeviceContext_IMPLS

219: // shorthand for casting dctx->data to the appropriate object to access the handles
220: #define PDC_IMPLS_RCAST(impls_,obj_) reinterpret_cast<PetscDeviceContext_(impls_) *>((obj_)->data)

222: #endif /* PETSCDEVICECONTEXTCUDA_HPP */