This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 889d57570380862190ffbb4e9f02faa7b5827005
Author: Mark Dokter <[email protected]>
AuthorDate: Mon Feb 22 17:00:42 2021 +0100

    [SYSTEMDS-2853] Refactor spoof cuda runtime operations

    This commit cleans up the initial implementation of the runtime parts
    of the CUDA codegen. It not only improves the structure of the code,
    but also avoids unnecessary conditionals (these are moved to the
    compile phase).

    * The SpoofCUDA class is now split into several subclasses of
      SpoofOperator, which additionally implement the SpoofCUDAOperator
      interface.
    * The native parts make more use of templating, which makes the code
      more compact and also avoids unnecessary conditionals.
    * More string comparisons are replaced by array indexing.
    * The return value of native methods is now used as a status
      indicator; the scalar result (which the return value previously
      carried for full_agg ops) is now downloaded directly from the GPU
      buffer. (A fallback to Java ops is on the ToDo list.)
    * The single/double precision decision is now made at startup
      according to the configuration (similar to LibMatrixCUDA's
      cudaSupportFunctions class member).
---
 src/main/cuda/CMakeLists.txt                       |  20 +-
 src/main/cuda/headers/Matrix.h                     |  22 +-
 src/main/cuda/headers/agg_ops.cuh                  |   8 +
 .../cuda/headers/intellisense_cuda_intrinsics.h    | 215 -----
 src/main/cuda/headers/reduction.cuh                |  41 +-
 src/main/cuda/kernels/reduction.cu                 |  33 +-
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  |  83 ++---
 src/main/cuda/spoof-launcher/SpoofCUDAContext.h    | 414 +++------------------
 src/main/cuda/spoof-launcher/SpoofCellwise.h       | 203 ++++++++++
 src/main/cuda/spoof-launcher/SpoofOperator.h       | 104 ++++++
 src/main/cuda/spoof-launcher/SpoofRowwise.h        |  77 ++++
 src/main/cuda/spoof-launcher/host_utils.h          |   4 +-
 src/main/cuda/spoof-launcher/jni_bridge.cpp        | 305 +++++++--------
 src/main/cuda/spoof-launcher/jni_bridge.h          |  60 +--
 src/main/java/org/apache/sysds/common/Types.java   |  12 +-
 .../apache/sysds/hops/codegen/SpoofCompiler.java   |  43 ++-
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java |   3 +-
 .../apache/sysds/runtime/codegen/CodegenUtils.java |  18 +-
 .../apache/sysds/runtime/codegen/SpoofCUDA.java    | 150 --------
 .../sysds/runtime/codegen/SpoofCUDACellwise.java   | 156 ++++++++
 .../sysds/runtime/codegen/SpoofCUDAOperator.java   | 173 +++++++++
 .../sysds/runtime/codegen/SpoofCUDARowwise.java    | 123 ++++++
 .../sysds/runtime/codegen/SpoofCellwise.java       |  14 +-
 .../sysds/runtime/codegen/SpoofMultiAggregate.java |   5 +
 .../sysds/runtime/codegen/SpoofOperator.java       |  11 +-
 .../sysds/runtime/codegen/SpoofOuterProduct.java   |   5 +
 .../apache/sysds/runtime/codegen/SpoofRowwise.java |  80 ++--
 .../instructions/gpu/SpoofCUDAInstruction.java     | 131 ++++---
 .../instructions/gpu/context/GPUObject.java        |   2 +-
 .../test/functions/codegen/CellwiseTmplTest.java   |   2 +-
 .../test/functions/codegen/RowAggTmplTest.java     |   2 +-
 31 files changed, 1376 insertions(+), 1143 deletions(-)

diff --git a/src/main/cuda/CMakeLists.txt b/src/main/cuda/CMakeLists.txt
index cfa72c4..ee13087 100644
--- a/src/main/cuda/CMakeLists.txt
+++ b/src/main/cuda/CMakeLists.txt
@@ -90,16 +90,28 @@ endif()
 
 set(SPOOF_HEADERS
 	spoof-launcher/jni_bridge.h
-	spoof-launcher/SpoofCUDAContext.h headers/Matrix.h headers/TempStorage.cuh)
+	spoof-launcher/SpoofCUDAContext.h
+	headers/Matrix.h
+	spoof-launcher/SpoofOperator.h
+	spoof-launcher/SpoofRowwise.h
+	spoof-launcher/SpoofCellwise.h)
+
 set(SPOOF_SOURCES spoof-launcher/jni_bridge.cpp
-	spoof-launcher/SpoofCUDAContext.cpp)
+	spoof-launcher/SpoofCUDAContext.cpp
+	)
+
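To make the third bullet of the commit message concrete: the prebuilt reduction
kernels are no longer looked up by concatenated name strings such as
"reduce_sum_d" but by enum pairs, so a lookup compares two integers instead of
characters. A condensed sketch of that pattern (the real maps live in
SpoofCUDAContext, one per precision; names follow the patch):

  #include <cuda.h>    // CUfunction, CUDA driver API
  #include <map>
  #include <utility>

  enum class AggType : int { NO_AGG, FULL_AGG, ROW_AGG, COL_AGG };
  enum class AggOp   : int { SUM, SUM_SQ, MIN, MAX };
  using AggKey = std::pair<AggType, AggOp>;

  // one map per precision; keys compare as plain ints, not strings
  std::map<AggKey, CUfunction> reduction_kernels_d;

  CUfunction full_agg_sum_d() {
      return reduction_kernels_d[{AggType::FULL_AGG, AggOp::SUM}];
  }
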
set(SPOOF_CUDA_HEADERS headers/agg_ops.cuh headers/reduction.cuh headers/spoof_utils.cuh - headers/utils.cuh headers/operators.cuh headers/Matrix.h headers/vector_write.cuh headers/vector_add.cuh) - #headers/intellisense_cuda_intrinsics.h + headers/TempStorage.cuh + headers/utils.cuh + headers/operators.cuh + headers/Matrix.h + headers/vector_write.cuh + headers/vector_add.cuh) + set(SPOOF_TEMPLATES spoof/cellwise.cu spoof/rowwise.cu) diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h index a11fc33..808590b 100644 --- a/src/main/cuda/headers/Matrix.h +++ b/src/main/cuda/headers/Matrix.h @@ -22,17 +22,23 @@ #define SYSTEMDS_MATRIX_H using uint32_t = unsigned int; +using int32_t = int; template <typename T> struct Matrix { - T* data; - uint32_t* row_ptr; - uint32_t* col_idx; - + int32_t nnz; uint32_t rows; uint32_t cols; - uint32_t nnz; + + uint32_t* row_ptr; + uint32_t* col_idx; + T* data; + typedef T value_type; + + explicit Matrix(size_t* jvals) : nnz(jvals[0]), rows(jvals[1]), cols(jvals[2]), + row_ptr(reinterpret_cast<uint32_t*>(jvals[3])), + col_idx(reinterpret_cast<uint32_t*>((jvals[4]))), data(reinterpret_cast<T*>(jvals[5])) {} }; //#ifdef __CUDACC_RTC__ @@ -42,7 +48,7 @@ template<typename T> uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) { upper -= 1; while(lower <= (upper-1)) { - uint32_t idx = lower + upper >> 1; + uint32_t idx = (lower + upper) >> 1; uint32_t vi = values[idx]; if (vi < val) lower = idx + 1; @@ -140,6 +146,10 @@ public: __device__ uint32_t* indexes() { return _mat->row_ptr; } + + __device__ bool hasData() { + return _mat->data != nullptr; + } private: __device__ uint32_t len_dense() { return _mat->rows * _mat->cols; diff --git a/src/main/cuda/headers/agg_ops.cuh b/src/main/cuda/headers/agg_ops.cuh index aec7d31..ff5a734 100644 --- a/src/main/cuda/headers/agg_ops.cuh +++ b/src/main/cuda/headers/agg_ops.cuh @@ -204,6 +204,14 @@ struct MaxOp<float> { __device__ __forceinline__ float operator()(float a, float b) const { return fmaxf(a, b); } + + __device__ __forceinline__ static float exec(float const & a, float const & b) { + return fmaxf(a, b); + } + + __device__ __forceinline__ static float init() { + return MaxNeutralElement<float>::get(); + } }; /** diff --git a/src/main/cuda/headers/intellisense_cuda_intrinsics.h b/src/main/cuda/headers/intellisense_cuda_intrinsics.h deleted file mode 100644 index d1b45d4..0000000 --- a/src/main/cuda/headers/intellisense_cuda_intrinsics.h +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#ifndef INTELLISENSE_CUDA_INTRINSICS_H -#define INTELLISENSE_CUDA_INTRINSICS_H -#pragma once - -#if 1 - #define __INTELLISENSE__ -#endif - -#ifdef __INTELLISENSE__ - -#include <cmath> -#include <cstdio> -#include <corecrt_math.h> - -#include <cuda.h> -#include <cuda_runtime.h> -#include <device_launch_parameters.h> - -#define CUDART_INF 0.0 -#define CUDART_INF_F 0.0f -#define CUDART_NAN -#define CUDART_NAN_F - -// Reverse the bit order of a 32 bit unsigned integer. -__device__ unsigned int __brev(unsigned int x) {}; - -//Reverse the bit order of a 64 bit unsigned integer. -__device__ unsigned long long int __brevll(unsigned long long int x) {}; - - -//Return selected bytes from two 32 bit unsigned integers. -__device__ unsigned int __byte_perm(unsigned int x, unsigned int y, unsigned int s) {}; - - -//Return the number of consecutive high - order zero bits in a 32 bit integer. -__device__ int __clz(int x) {}; - - -//Count the number of consecutive high - order zero bits in a 64 bit integer. -__device__ int __clzll(long long int x) {}; - - -//Find the position of the least significant bit set to 1 in a 32 bit integer. -__device__ int __ffs(int x) {}; - - -//Find the position of the least significant bit set to 1 in a 64 bit integer.Concatenate hi : lo, shift left by shift & 31 bits, return the most significant 32 bits. -__device__ int __ffsll(long long int x) {}; - - -//Concatenate hi : lo, shift left by shift & 31 bits, return the most significant 32 bits. -__device__ unsigned int __funnelshift_l(unsigned int lo, unsigned int hi, unsigned int shift) {}; - - -//Concatenate hi : lo, shift left by min(shift, 32) bits, return the most significant 32 bits. -__device__ unsigned int __funnelshift_lc(unsigned int lo, unsigned int hi, unsigned int shift) {}; - - -//Concatenate hi : lo, shift right by shift & 31 bits, return the least significant 32 bits. -__device__ unsigned int __funnelshift_r(unsigned int lo, unsigned int hi, unsigned int shift) {}; - - -//Concatenate hi : lo, shift right by min(shift, 32) bits, return the least significant 32 bits. -__device__ unsigned int __funnelshift_rc(unsigned int lo, unsigned int hi, unsigned int shift) {}; - - -//Compute average of signed input arguments, avoiding overflow in the intermediate sum. -__device__ int __hadd(int, int) {}; - - -//Calculate the least significant 32 bits of the product of the least significant 24 bits of two integers. -__device__ int __mul24(int x, int y) {}; - - -//Calculate the most significant 64 bits of the product of the two 64 bit integers. -__device__ long long int __mul64hi(long long int x, long long int y) {}; - - -//Calculate the most significant 32 bits of the product of the two 32 bit integers. -__device__ int __mulhi(int x, int y) {}; - - -//Count the number of bits that are set to 1 in a 32 bit integer. -__device__ int __popc(unsigned int x) {}; - - -//Count the number of bits that are set to 1 in a 64 bit integer. -__device__ int __popcll(unsigned long long int x) {}; - - -//Compute rounded average of signed input arguments, avoiding overflow in the intermediate sum. -__device__ int __rhadd(int, int) {}; - - -//Calculate | x − y | +z, the sum of absolute difference. -__device__ unsigned int __sad(int x, int y, unsigned int z) {}; - - -//Compute average of unsigned input arguments, avoiding overflow in the intermediate sum. -__device__ unsigned int __uhadd(unsigned int, unsigned int) {}; - - -//Calculate the least significant 32 bits of the product of the least significant 24 bits of two unsigned integers. 
-__device__ unsigned int __umul24(unsigned int x, unsigned int y) {}; - - -//Calculate the most significant 64 bits of the product of the two 64 unsigned bit integers. -__device__ unsigned long long int __umul64hi(unsigned long long int x, unsigned long long int y) {}; - - -//Calculate the most significant 32 bits of the product of the two 32 bit unsigned integers. -__device__ unsigned int __umulhi(unsigned int x, unsigned int y) {}; - - -//Compute rounded average of unsigned input arguments, avoiding overflow in the intermediate sum. -__device__ unsigned int __urhadd(unsigned int, unsigned int) {}; - - -//Calculate | x − y | +z, the sum of absolute difference. -__device__ unsigned int __usad(unsigned int x, unsigned int y, unsigned int z) {}; - -////////////////////////////////////////////////////// -//atomic functions - -int atomicAdd(int* address, int val) {}; -unsigned int atomicAdd(unsigned int* address, unsigned int val) {}; -unsigned long long int atomicAdd(unsigned long long int* address, unsigned long long int val) {}; -float atomicAdd(float* address, float val) {}; -double atomicAdd(double* address, double val) {}; - -typedef int __half2; -typedef short __half; -__half2 atomicAdd(__half2* address, __half2 val) {}; -__half atomicAdd(__half* address, __half val) {}; - -int atomicSub(int* address, int val) {}; -unsigned int atomicSub(unsigned int* address, unsigned int val) {}; - -int atomicExch(int* address, int val) {}; -unsigned int atomicExch(unsigned int* address, unsigned int val) {}; -unsigned long long int atomicExch(unsigned long long int* address, unsigned long long int val) {}; -float atomicExch(float* address, float val) {}; - -int atomicMin(int* address, int val) {}; -unsigned int atomicMin(unsigned int* address, unsigned int val) {}; -unsigned long long int atomicMin(unsigned long long int* address, unsigned long long int val) {}; - -int atomicMax(int* address, int val) {}; -unsigned int atomicMax(unsigned int* address, unsigned int val) {}; -unsigned long long int atomicMax(unsigned long long int* address, unsigned long long int val) {}; - -unsigned int atomicInc(unsigned int* address, unsigned int val) {}; - -unsigned int atomicDec(unsigned int* address, unsigned int val) {}; - -int atomicCAS(int* address, int compare, int val) {}; -unsigned int atomicCAS(unsigned int* address, unsigned int compare, unsigned int val) {}; -unsigned long long int atomicCAS(unsigned long long int* address, - unsigned long long int compare, - unsigned long long int val) {}; -unsigned short int atomicCAS(unsigned short int* address, - unsigned short int compare, - unsigned short int val) {}; - -int atomicAnd(int* address, int val) {}; -unsigned int atomicAnd(unsigned int* address, - unsigned int val) {}; -unsigned long long int atomicAnd(unsigned long long int* address, - unsigned long long int val) {}; - -int atomicOr(int* address, int val) {}; -unsigned int atomicOr(unsigned int* address, - unsigned int val) {}; -unsigned long long int atomicOr(unsigned long long int* address, - unsigned long long int val) {}; - -int atomicXor(int* address, int val) {}; -unsigned int atomicXor(unsigned int* address, unsigned int val) {}; -unsigned long long int atomicXor(unsigned long long int* address, unsigned long long int val) {}; - -template <typename T> -unsigned int __match_any_sync(unsigned mask, T value) {}; -template <typename T> -unsigned int __match_all_sync(unsigned mask, T value, int* pred) {}; - -template <typename T> -T __shfl_sync(unsigned mask, T var, int srcLane, int width = warpSize) 
{}; -template <typename T> -T __shfl_up_sync(unsigned mask, T var, unsigned int delta, int width = warpSize) {}; -template <typename T> -T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width = warpSize) {}; -template <typename T> -T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize) {}; - -#endif // __INTELLISENSE__ - -#endif // INTELLISENSE_CUDA_INTRINSICS_H \ No newline at end of file diff --git a/src/main/cuda/headers/reduction.cuh b/src/main/cuda/headers/reduction.cuh index 568eb12..88cc45c 100644 --- a/src/main/cuda/headers/reduction.cuh +++ b/src/main/cuda/headers/reduction.cuh @@ -116,8 +116,8 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t if (blockDim.x >= 64) { smem[tid] = v = reduction_op(v, smem[tid + 32]); } - if(tid<12) - printf("bid=%d tid=%d reduction result: %3.1f\n", blockIdx.x, tid, sdata[tid]); +// if(tid<12) +// printf("bid=%d tid=%d reduction result: %3.1f\n", blockIdx.x, tid, sdata[tid]); if (blockDim.x >= 32) { smem[tid] = v = reduction_op(v, smem[tid + 16]); @@ -149,7 +149,7 @@ __device__ void FULL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t // write result for this block to global mem if (tid == 0) { // if(gridDim.x < 10) - printf("blockIdx.x=%d reduction result: %3.1f\n", blockIdx.x, sdata[0]); +// printf("blockIdx.x=%d reduction result: %3.1f\n", blockIdx.x, sdata[0]); out->val(0, blockIdx.x) = sdata[0]; } } @@ -317,7 +317,8 @@ __device__ void NO_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t N uint32_t last_idx = min(first_idx + static_cast<uint32_t>(VT), N); #pragma unroll for(auto i = first_idx; i < last_idx; i++) { - T result = spoof_op(in->vals(0)[i], i, i / in->cols(), i % in->cols()); + T a = in->hasData() ? in->vals(0)[i] : 0; + T result = spoof_op(a, i, i / in->cols(), i % in->cols()); out->vals(0)[i] = result; //if(i < 4) // printf("tid=%d in=%4.3f res=%4.3f out=%4.3f r=%d\n", i, in->vals(0)[i], result, out->vals(0)[i], i/in->cols()); @@ -334,28 +335,24 @@ __device__ void NO_AGG_SPARSE(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uin // uint32_t cix = in->col_idxs(0)[gtid]; uint32_t row_start = in->pos(rix); uint32_t row_len = in->row_len(rix); -//// if(cix == 0) { -// //if(row_start == gtid) { -// //if(threadIdx.x == 0) { -// if(rix < in->rows()) { -// if(rix > 895 && rix < 905) -// printf("gtid=%d in->indexes()[rix=%d]=%d rowlen=%d row_start=%d cix=%d\n", gtid, rix, in->indexes()[rix], in->row_len(rix), row_start, cix); -//// out->indexes()[gtid] = in->indexes()[gtid]; -// out->indexes()[rix] = in->indexes()[rix]; -// } - while(tid < row_len) { - - uint32_t* aix = in->col_idxs(rix); - uint32_t cix = aix[tid]; + if(in->hasData()) { + uint32_t *aix = in->col_idxs(rix); + uint32_t cix = aix[tid]; // T result = spoof_op(in->val(rix, cix), rix*in->rows()+cix, rix, cix); - T result = spoof_op(in->val(row_start+tid), rix*in->rows()+cix, rix, cix); - out->set(row_start+tid, cix, result); - - //if(rix > 899 && rix < 903 && cix==0) - // printf("rix=%d row_start=%d tid=%d result=%4.3f\n", rix, row_start, tid, result); + T result = spoof_op(in->val(row_start + tid), rix * in->rows() + cix, rix, cix); + out->set(row_start + tid, cix, result); +// if(rix > 899 && rix < 903 && cix==0) +// if(rix < 10 && cix==0) +// printf("rix=%d row_start=%d tid=%d result=%4.3f\n", rix, row_start, tid, result); + } + else { + uint32_t cix = tid; + T result = spoof_op(0, rix * in->rows() + cix, rix, cix); + out->set(row_start + tid, cix, result); + } tid+=blockDim.x; diff 
--git a/src/main/cuda/kernels/reduction.cu b/src/main/cuda/kernels/reduction.cu index 3a11f77..b05a54e 100644 --- a/src/main/cuda/kernels/reduction.cu +++ b/src/main/cuda/kernels/reduction.cu @@ -22,10 +22,6 @@ #include "reduction.cuh" #include "Matrix.h" -using uint = unsigned int; -#include <cuda_runtime.h> -#ifdef __CUDACC__ - /** * Do a summation over all elements of an array/matrix * @param g_idata input data stored in device memory (of size n) @@ -39,12 +35,18 @@ __device__ void reduce_sum(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32 FULL_AGG<T, SumOp<T>, IdentityOp<T>>(in, out, n, (T) 0.0, agg_op, spoof_op); } +extern "C" __global__ void reduce_sum_f(Matrix<float>* in, Matrix<float>* out, uint32_t n) { + MatrixAccessor<float> _in(in); + MatrixAccessor<float> _out(out); + reduce_sum(&_in, &_out, n); +} + extern "C" __global__ void reduce_sum_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { MatrixAccessor<double> _in(in); MatrixAccessor<double> _out(out); reduce_sum(&_in, &_out, n); } -#endif + //extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata, uint n) { // reduce_sum(g_idata, g_odata, n); //} @@ -107,16 +109,18 @@ __device__ void reduce_max(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32 FULL_AGG<T, MaxOp<T>, IdentityOp<T>>(in, out, n, -MAX<T>(), agg_op, spoof_op); } +extern "C" __global__ void reduce_max_f(Matrix<float>* in, Matrix<float>* out, uint32_t n) { + MatrixAccessor<float> _in(in); + MatrixAccessor<float> _out(out); + reduce_max(&_in, &_out, n); +} + extern "C" __global__ void reduce_max_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { MatrixAccessor<double> _in(in); MatrixAccessor<double> _out(out); reduce_max(&_in, &_out, n); } -//extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata, uint n) { -// reduce_max(g_idata, g_odata, n); -//} - /** * Do a max over all rows of a matrix * @param g_idata input matrix stored in device memory (of size rows * cols) @@ -175,17 +179,18 @@ __device__ void reduce_min(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32 FULL_AGG<T, MinOp<T>, IdentityOp<T>>(in, out, n, MAX<T>(), agg_op, spoof_op); } +extern "C" __global__ void reduce_min_f(Matrix<float>* in, Matrix<float>* out, uint32_t n) { + MatrixAccessor<float> _in(in); + MatrixAccessor<float> _out(out); + reduce_min(&_in, &_out, n); +} + extern "C" __global__ void reduce_min_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { MatrixAccessor<double> _in(in); MatrixAccessor<double> _out(out); reduce_min(&_in, &_out, n); } -//extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata, uint n) { -// reduce_min(g_idata, g_odata, n); -//} - - /** * Do a min over all rows of a matrix * @param g_idata input matrix stored in device memory (of size rows * cols) diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp index 3e023d5..29dc46b 100644 --- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp +++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp @@ -46,7 +46,7 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, const char* resourc std::stringstream s1, s2; s1 << "-I" << resource_path << "/cuda/headers"; s2 << "-I" << resource_path << "/cuda/spoof"; - auto *ctx = new SpoofCUDAContext(resource_path, {s1.str(), s2.str(), cuda_include_path}); + auto ctx = new SpoofCUDAContext(resource_path,{s1.str(), s2.str(), cuda_include_path}); // cuda device is handled by jCuda atm //cudaSetDevice(device_id); 
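A note on the reduce_*_f/reduce_*_d wrappers added in reduction.cu above: a
templated __global__ function gets a C++-mangled symbol name, which the driver
API cannot resolve from a fixed string, so each precision is exported through a
thin extern "C" wrapper with a stable, unmangled name. Condensed sketch of the
pattern (taken from the shapes visible in this diff):

  // the device-side implementation is written once, as a template
  template<typename T>
  __device__ void reduce_sum(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t n);

  // ...and exported once per precision under an unmangled name
  extern "C" __global__ void reduce_sum_f(Matrix<float>* in, Matrix<float>* out, uint32_t n) {
      MatrixAccessor<float> _in(in);
      MatrixAccessor<float> _out(out);
      reduce_sum(&_in, &_out, n);
  }

  // host side can then fetch the kernel handle by its literal name:
  // CHECK_CUDA(cuModuleGetFunction(&func, module, "reduce_sum_f"));
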
//cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync); @@ -56,79 +56,78 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t device_id, const char* resourc CUfunction func; - // ToDo: implement a more scalable solution for these imports - // SUM + CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_f")); + ctx->reduction_kernels_f.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::SUM), func)); CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_d")); - ctx->reduction_kernels.insert(std::make_pair("reduce_sum_d", func)); - // CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_f")); - // ctx->reduction_kernels.insert(std::make_pair("reduce_sum_f", func)); - // + ctx->reduction_kernels_d.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::SUM), func)); + // // SUM_SQ // CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_d")); // ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_d", func)); // CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_f")); // ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_f", func)); - // - // // MIN + + // MIN + CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_f")); + ctx->reduction_kernels_f.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::MIN), func)); CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_d")); - ctx->reduction_kernels.insert(std::make_pair("reduce_min_d", func)); - // CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_f")); - // ctx->reduction_kernels.insert(std::make_pair("reduce_min_f", func)); - // - // // MAX + ctx->reduction_kernels_d.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::MIN), func)); + + // MAX + CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_f")); + ctx->reduction_kernels_f.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::MAX), func)); CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_d")); - ctx->reduction_kernels.insert(std::make_pair("reduce_max_d", func)); - // CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_f")); - // ctx->reduction_kernels.insert(std::make_pair("reduce_max_f", func)); - + ctx->reduction_kernels_d.insert(std::make_pair(std::make_pair(SpoofOperator::AggType::FULL_AGG, SpoofOperator::AggOp::MAX), func)); + return reinterpret_cast<size_t>(ctx); } void SpoofCUDAContext::destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id) { - delete ctx; - ctx = nullptr; - // cuda device is handled by jCuda atm - //cudaDeviceReset(); + delete ctx; + // cuda device is handled by jCuda atm + //cudaDeviceReset(); } -int SpoofCUDAContext::compile(const std::string &src, const std::string &name, SpoofOperator::OpType op_type, - SpoofOperator::AggType agg_type, SpoofOperator::AggOp agg_op, SpoofOperator::RowType row_type, bool sparse_safe, - int32_t const_dim2, uint32_t num_vectors, bool TB1) { +int SpoofCUDAContext::compile(std::unique_ptr<SpoofOperator> op, const std::string &src) { #ifdef _DEBUG // std::cout << "---=== START source listing of spoof cuda kernel [ " << name << " ]: " << std::endl; // uint32_t line_num = 0; // std::istringstream src_stream(src); // for(std::string line; std::getline(src_stream, line); line_num++) // std::cout << line_num << ": " << line << std::endl; -// // std::cout << "---=== 
END source listing of spoof cuda kernel [ " << name << " ]." << std::endl; std::cout << "cwd: " << std::filesystem::current_path() << std::endl; std::cout << "include_paths: "; for_each (include_paths.begin(), include_paths.end(), [](const std::string& line){ std::cout << line << '\n';}); std::cout << std::endl; #endif -// uncomment all related lines for temporary timing output: -// auto start = clk::now(); +// uncomment all related lines for temporary timing output: // auto compile_start = clk::now(); - jitify::Program program = kernel_cache.program(src, 0, include_paths); + op->program = std::make_unique<jitify::Program>(kernel_cache.program(src, 0, include_paths)); // auto compile_end = clk::now(); // auto compile_duration = std::chrono::duration_cast<sec>(compile_end - compile_start).count(); - ops.insert(std::make_pair(name, SpoofOperator({std::move(program), op_type, agg_type, agg_op, row_type, name, - const_dim2, num_vectors, TB1, sparse_safe}))); - -// auto end = clk::now(); - -// auto handling_duration = std::chrono::duration_cast<sec>(end - start).count() - compile_duration; + compiled_ops.push_back(std::move(op)); // compile_total += compile_duration; -// handling_total += handling_duration; +// std::cout << name << " compiled in " +// << compile_duration << " seconds. Total compile time (abs/rel): " +// << compile_total << "/" << compiled_ops.size() << std::endl; + return compiled_ops.size() - 1; +} -// std::cout << name << " times [s] handling/compile/totals(h/c)/count: " -// << handling_duration << "/" -// << compile_duration << "/" -// << handling_total << "/" -// << compile_total << "/" << compile_count + 1 << std::endl; - return compile_count++; +template<typename T> +CUfunction SpoofCUDAContext::getReductionKernel(const std::pair<SpoofOperator::AggType, SpoofOperator::AggOp> &key) { + return nullptr; +} +template<> +CUfunction SpoofCUDAContext::getReductionKernel<float>(const std::pair<SpoofOperator::AggType, + SpoofOperator::AggOp> &key) { + return reduction_kernels_f[key]; } +template<> +CUfunction SpoofCUDAContext::getReductionKernel<double>(const std::pair<SpoofOperator::AggType, + SpoofOperator::AggOp> &key) { + return reduction_kernels_d[key]; +} \ No newline at end of file diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h index 89dd9e0..ab0f098 100644 --- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h +++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h @@ -25,403 +25,93 @@ #define NOMINMAX #endif -#include <cmath> -#include <cstdint> -#include <map> -#include <string> -#include <algorithm> - -#include "Matrix.h" - #ifndef NDEBUG -#define _DEBUG -#endif -#ifdef _DEBUG - #define JITIFY_PRINT_ALL 1 + #define _DEBUG #endif +//#ifdef _DEBUG +// #define JITIFY_PRINT_ALL 1 +//#endif -#include <jitify.hpp> +#include <map> +#include <string> #include <utility> #include <cublas_v2.h> - -#include "host_utils.h" +#include <jitify.hpp> +#include "SpoofOperator.h" +#include "Matrix.h" using jitify::reflection::type_of; -struct SpoofOperator { - enum class OpType : int { CW, RA, MA, OP, NONE }; - enum class AggType : int { NONE, NO_AGG, FULL_AGG, ROW_AGG, COL_AGG }; - enum class AggOp : int { NONE, SUM, SUM_SQ, MIN, MAX }; - enum class RowType : int { NONE, FULL_AGG = 4 }; - - jitify::Program program; - OpType op_type; - AggType agg_type; - AggOp agg_op; - RowType row_type; - const std::string name; - int32_t const_dim2; - uint32_t num_temp_vectors; - bool TB1 = false; - bool sparse_safe = true; -}; - class 
SpoofCUDAContext { jitify::JitCache kernel_cache; - std::map<const std::string, SpoofOperator> ops; + std::vector<std::unique_ptr<SpoofOperator>> compiled_ops; CUmodule reductions; - std::map<const std::string, CUfunction> reduction_kernels; - double handling_total, compile_total; - uint32_t compile_count; + std::map<std::pair<SpoofOperator::AggType, SpoofOperator::AggOp>, CUfunction> reduction_kernels_f; + std::map<std::pair<SpoofOperator::AggType, SpoofOperator::AggOp>, CUfunction> reduction_kernels_d; + +// double handling_total, compile_total; const std::string resource_path; const std::vector<std::string> include_paths; public: - // ToDo: make launch config more adaptive - // num threads - const int NT = 256; - - // values / thread - const int VT = 4; explicit SpoofCUDAContext(const char* resource_path_, std::vector<std::string> include_paths_) : reductions(nullptr), - resource_path(resource_path_), include_paths(std::move(include_paths_)), handling_total(0.0), compile_total(0.0), - compile_count(0) {} + resource_path(resource_path_), include_paths(std::move(include_paths_)) + //,handling_total(0.0), compile_total(0.0) + {} static size_t initialize_cuda(uint32_t device_id, const char* resource_path_); static void destroy_cuda(SpoofCUDAContext *ctx, uint32_t device_id); - - int compile(const std::string &src, const std::string &name, SpoofOperator::OpType op_type, - SpoofOperator::AggType agg_type = SpoofOperator::AggType::NONE, - SpoofOperator::AggOp agg_op = SpoofOperator::AggOp::NONE, - SpoofOperator::RowType row_type = SpoofOperator::RowType::NONE, bool sparse_safe = true, - int32_t const_dim2 = -1, uint32_t num_vectors = 0, bool TB1 = false); - - - template <typename T> - T execute_kernel(const std::string &name, std::vector<Matrix<T>>& input, std::vector<Matrix<T>>& sides, Matrix<T>* output, - T *scalars_ptr, uint32_t num_scalars, uint32_t grix) { - - T result = 0.0; - size_t dev_buf_size; - Matrix<T>* d_in = nullptr; - Matrix<T>* d_out = nullptr; - Matrix<T>* d_sides = nullptr; - T* b1_transposed = nullptr; - T *d_scalars = nullptr; - - auto o = ops.find(name); - if (o != ops.end()) { - SpoofOperator *op = &(o->second); - - // ToDo: multiple inputs for SpoofOuterProduct template - CHECK_CUDART(cudaMalloc((void **)&d_in, sizeof(Matrix<T>))); - CHECK_CUDART(cudaMemcpy(d_in, reinterpret_cast<void*>(&input[0]), sizeof(Matrix<T>), cudaMemcpyHostToDevice)); - - if(output != nullptr) { - if (op->sparse_safe && input.front().row_ptr != nullptr) { -#ifdef _DEBUG - std::cout << "copying sparse safe row ptrs" << std::endl; -#endif - CHECK_CUDART(cudaMemcpy(output->row_ptr, input.front().row_ptr, (input.front().rows+1)*sizeof(uint32_t), cudaMemcpyDeviceToDevice)); - } - - CHECK_CUDART(cudaMalloc((void **) &d_out, sizeof(Matrix<T>))); - //CHECK_CUDART(cudaMemset(out->data, 0, out->rows*out->cols*sizeof(T))); - CHECK_CUDART(cudaMemcpy(d_out, reinterpret_cast<void *>(output), sizeof(Matrix<T>), - cudaMemcpyHostToDevice)); - - } - else { - uint32_t num_blocks = 1; - if (op->op_type == SpoofOperator::OpType::CW) - num_blocks = std::ceil(((input.front().rows * input.front().cols) + NT * 2 - 1) / (NT * 2)); - - CHECK_CUDART(cudaMalloc((void **) &d_out, sizeof(Matrix<T>))); - T* d_out_data = nullptr; - CHECK_CUDART(cudaMalloc((void **) &d_out_data, sizeof(T) * num_blocks)); - Matrix<T> agg_out{d_out_data, 0, 0, num_blocks, 1, num_blocks}; - CHECK_CUDART(cudaMemcpy(d_out, reinterpret_cast<void *>(&agg_out), sizeof(Matrix<T>), - cudaMemcpyHostToDevice)); - } - - if (!sides.empty()) { - if(op->TB1) { 
-#ifdef _DEBUG - std::cout << "transposing TB1 for " << op->name << std::endl; -#endif - T* b1 = sides[0].data; - uint32_t m = sides[0].rows; - uint32_t n = sides[0].cols; - - cudaMalloc(reinterpret_cast<void**>(&b1_transposed), sizeof(T) * m * n); - double alpha = 1.0; - double beta = 0.0; - cublasHandle_t handle; - - CHECK_CUBLAS(cublasCreate(&handle)); - CHECK_CUBLAS(cublasDgeam(handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, &alpha, b1, n, &beta, b1, n, b1_transposed, m)); - sides[0].data = b1_transposed; - sides[0].rows = n; - sides[0].cols = m; - CHECK_CUBLAS(cublasDestroy(handle)); - } - dev_buf_size = sizeof(Matrix<T>) * sides.size(); - CHECK_CUDART(cudaMalloc(reinterpret_cast<void **>(&d_sides), dev_buf_size)); - CHECK_CUDART(cudaMemcpy(d_sides, &sides[0], dev_buf_size, cudaMemcpyHostToDevice)); - } - - if (num_scalars > 0) { - dev_buf_size = sizeof(T) * num_scalars; - CHECK_CUDART(cudaMalloc((void **)&d_scalars, dev_buf_size)); - CHECK_CUDART(cudaMemcpy(d_scalars, scalars_ptr, dev_buf_size, cudaMemcpyHostToDevice)); - } - - switch(op->op_type) { - case SpoofOperator::OpType::CW: - launch_cw_kernel(op, d_in, d_out, d_sides, sides.size(), d_scalars, input.front().rows, - input.front().cols, grix, input[0]); - break; - case SpoofOperator::OpType::RA: - launch_ra_kernel(op, d_in, d_out, d_sides, sides.size(), d_scalars, input.front().rows, - input.front().cols, grix, input[0].row_ptr!=nullptr); - break; - default: - throw std::runtime_error("error: unknown spoof operator"); - } - - if (num_scalars > 0) - CHECK_CUDART(cudaFree(d_scalars)); - - if (sides.size() > 0) - CHECK_CUDART(cudaFree(d_sides)); - - if (op->TB1) - CHECK_CUDART(cudaFree(b1_transposed)); - - if (op->agg_type == SpoofOperator::AggType::FULL_AGG || op->row_type == SpoofOperator::RowType::FULL_AGG) { - Matrix<T> res_mat; - CHECK_CUDART(cudaMemcpy(&res_mat, d_out, sizeof(Matrix<T>), cudaMemcpyDeviceToHost)); - CHECK_CUDART(cudaMemcpy(&result, res_mat.data, sizeof(T), cudaMemcpyDeviceToHost)); - - CHECK_CUDART(cudaFree(res_mat.data)); - CHECK_CUDART(cudaFree(d_out)); - } - } - else { - throw std::runtime_error("kernel " + name + " not found."); - } - return result; - } - - template<typename T> - std::string determine_agg_kernel(SpoofOperator* op) { - std::string reduction_kernel_name; - std::string reduction_type; - std::string suffix = (typeid(T) == typeid(double) ? 
"_d" : "_f"); - switch (op->agg_type) { - case SpoofOperator::AggType::FULL_AGG: - reduction_type = "_"; - break; - case SpoofOperator::AggType::ROW_AGG: - reduction_type = "_row_"; - break; - case SpoofOperator::AggType::COL_AGG: - reduction_type = "_col_"; - break; - default: - throw std::runtime_error("unknown reduction type"); - return ""; - } - switch (op->agg_op) { - case SpoofOperator::AggOp::MIN: - reduction_kernel_name = "reduce" + reduction_type + "min" + suffix; - break; - case SpoofOperator::AggOp::MAX: - reduction_kernel_name = "reduce" + reduction_type + "max" + suffix; - break; - case SpoofOperator::AggOp::SUM_SQ: - reduction_kernel_name = "reduce" + reduction_type + "sum_sq" + suffix; - break; - case SpoofOperator::AggOp::SUM: - reduction_kernel_name = "reduce" + reduction_type + "sum" + suffix; - break; - default: - throw std::runtime_error("unknown reduction op"); - } + int compile(std::unique_ptr<SpoofOperator> op, const std::string &src); - return reduction_kernel_name; - } + template <typename T, typename CALL> + int launch(uint32_t opID, std::vector<Matrix<T>>& input, std::vector<Matrix<T>>& sides, Matrix<T>& output, + T* scalars, uint32_t grix) { + // dp holds in/side/out/scalar pointers for GPU + DevMatPtrs<T> dp; - template<typename T> - void launch_cw_kernel(SpoofOperator* op, Matrix<T>* d_in, Matrix<T>* d_out, Matrix<T>* d_sides, uint32_t num_sides, - T* d_scalars, uint32_t in_rows, uint32_t in_cols, uint32_t grix, const Matrix<T>& h_in) { - T value_type; - bool sparse = h_in.row_ptr != nullptr; - uint32_t N = in_rows * in_cols; - std::string op_name(op->name + "_DENSE"); - if(sparse) { - op_name = std::string(op->name + "_SPARSE"); - N = h_in.nnz; - } + SpoofOperator* op = compiled_ops[opID].get(); - switch (op->agg_type) { - case SpoofOperator::AggType::FULL_AGG: { - // num ctas - uint32_t NB = std::ceil((N + NT * 2 - 1) / (NT * 2)); - dim3 grid(NB, 1, 1); - dim3 block(NT, 1, 1); - uint32_t shared_mem_size = NT * sizeof(T); + CHECK_CUDART(cudaMalloc((void **)&dp.in, sizeof(Matrix<T>) * input.size())); + CHECK_CUDART(cudaMemcpy(dp.in, reinterpret_cast<void*>(&input[0]), sizeof(Matrix<T>) * input.size(), + cudaMemcpyHostToDevice)); -#ifdef _DEBUG - // ToDo: connect output to SystemDS logging facilities - std::cout << "launching spoof cellwise kernel " << op_name << " with " - << NT * NB << " threads in " << NB << " blocks and " - << shared_mem_size - << " bytes of shared memory for full aggregation of " - << N << " elements" - << std::endl; -#endif - - CHECK_CUDA(op->program.kernel(op_name) - .instantiate(type_of(value_type), std::max(1u, num_sides)) - .configure(grid, block, shared_mem_size) - .launch(d_in, d_sides, d_out, d_scalars, N, grix)); - - if(NB > 1) { - std::string reduction_kernel_name = determine_agg_kernel<T>(op); - CUfunction reduce_kernel = reduction_kernels.find(reduction_kernel_name)->second; - N = NB; - uint32_t iter = 1; - while (NB > 1) { - void* args[3] = { &d_out, &d_out, &N}; - - NB = std::ceil((N + NT * 2 - 1) / (NT * 2)); -#ifdef _DEBUG - std::cout << "agg iter " << iter++ << " launching spoof cellwise kernel " << op_name << " with " - << NT * NB << " threads in " << NB << " blocks and " - << shared_mem_size - << " bytes of shared memory for full aggregation of " - << N << " elements" - << std::endl; -#endif - CHECK_CUDA(cuLaunchKernel(reduce_kernel, - NB, 1, 1, - NT, 1, 1, - shared_mem_size, nullptr, args, nullptr)); - N = NB; - } - } - break; - } - case SpoofOperator::AggType::COL_AGG: { - // num ctas - uint32_t NB = std::ceil((N + NT 
- 1) / NT); - dim3 grid(NB, 1, 1); - dim3 block(NT, 1, 1); - uint32_t shared_mem_size = 0; -#ifdef _DEBUG - std::cout << " launching spoof cellwise kernel " << op_name << " with " - << NT * NB << " threads in " << NB << " blocks for column aggregation of " - << N << " elements" << std::endl; -#endif - CHECK_CUDA(op->program.kernel(op_name) - .instantiate(type_of(value_type), std::max(1u, num_sides)) - .configure(grid, block, shared_mem_size) - .launch(d_in, d_sides, d_out, d_scalars, N, grix)); - - break; + if (!sides.empty()) { + CHECK_CUDART(cudaMalloc(reinterpret_cast<void **>(&dp.sides), sizeof(Matrix<T>) * sides.size())); + CHECK_CUDART(cudaMemcpy(dp.sides, &sides[0], sizeof(Matrix<T>) * sides.size(), cudaMemcpyHostToDevice)); } - case SpoofOperator::AggType::ROW_AGG: { - // num ctas - uint32_t NB = in_rows; - dim3 grid(NB, 1, 1); - dim3 block(NT, 1, 1); - uint32_t shared_mem_size = NT * sizeof(T); -#ifdef _DEBUG - std::cout << " launching spoof cellwise kernel " << op_name << " with " - << NT * NB << " threads in " << NB << " blocks and " - << shared_mem_size << " bytes of shared memory for row aggregation of " - << N << " elements" << std::endl; -#endif - CHECK_CUDA(op->program.kernel(op_name) - .instantiate(type_of(value_type), std::max(1u, num_sides)) - .configure(grid, block, shared_mem_size) - .launch(d_in, d_sides, d_out, d_scalars, N, grix)); - - break; + + if (op->isSparseSafe() && input.front().row_ptr != nullptr) { + CHECK_CUDART(cudaMemcpy(output.row_ptr, input.front().row_ptr, (input.front().rows+1)*sizeof(uint32_t), + cudaMemcpyDeviceToDevice)); } - case SpoofOperator::AggType::NO_AGG: - default: { - // num ctas - // ToDo: VT not a template parameter anymore - uint32_t NB = std::ceil((N + NT * VT - 1) / (NT * VT)); - if(sparse) - NB = in_rows; - dim3 grid(NB, 1, 1); - dim3 block(NT, 1, 1); - uint32_t shared_mem_size = 0; - #ifdef _DEBUG - if(sparse) { - std::cout << "launching sparse spoof cellwise kernel " << op_name << " with " << NT * NB - << " threads in " << NB << " blocks without aggregation for " << N << " elements" - << std::endl; - } - else { - std::cout << "launching spoof cellwise kernel " << op_name << " with " << NT * NB - << " threads in " << NB << " blocks without aggregation for " << N << " elements" - << std::endl; - } + std::cout << "output rows: " << output.rows << " cols: " << output.cols << " nnz: " << output.nnz << " format: " << + (output.row_ptr == nullptr ? 
"dense" : "sparse") << std::endl; #endif + size_t out_num_elements = output.rows * output.cols; + if(output.row_ptr) + if(op->isSparseSafe() && output.nnz > 0) + out_num_elements = output.nnz; + CHECK_CUDART(cudaMalloc((void **) &dp.out, sizeof(Matrix<T>))); + CHECK_CUDART(cudaMemset(output.data, 0, out_num_elements * sizeof(T))); + CHECK_CUDART(cudaMemcpy(dp.out, reinterpret_cast<void *>(&output), sizeof(Matrix<T>), + cudaMemcpyHostToDevice)); + + dp.scalars = scalars; - CHECK_CUDA(op->program.kernel(op_name) - .instantiate(type_of(value_type), std::max(1u, num_sides)) - .configure(grid, block, shared_mem_size) - .launch(d_in, d_sides, d_out, d_scalars, N, grix)); - } - } + CALL::exec(this, op, input, sides, output, grix, dp); + + return 0; } + + std::string getOperatorName(uint32_t opID) { return compiled_ops.at(opID)->name; } template<typename T> - void launch_ra_kernel(SpoofOperator* op, Matrix<T>* d_in, Matrix<T>* d_out, Matrix<T>* d_sides, uint32_t num_sides, - T* d_scalars, uint32_t in_rows, uint32_t in_cols, uint32_t grix, bool sparse) { - T value_type; - dim3 grid(in_rows, 1, 1); - dim3 block(NT, 1, 1); - unsigned int shared_mem_size = NT * sizeof(T); - - uint32_t tmp_len = 0; - uint32_t temp_buf_size = 0; - T* d_temp = nullptr; - if(op->num_temp_vectors>0) { - tmp_len = std::max(in_cols, op->const_dim2 < 0 ? 0 : static_cast<uint32_t>(op->const_dim2)); - temp_buf_size = op->num_temp_vectors * tmp_len * in_rows * sizeof(T); - CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size)); - CHECK_CUDART(cudaMemset(d_temp, 0, temp_buf_size)); - } - - std::string name(op->name + "_DENSE"); - if(sparse) - name = std::string(op->name + "_SPARSE"); - -#ifdef _DEBUG - // ToDo: connect output to SystemDS logging facilities - std::cout << "launching spoof rowwise kernel " << name << " with " << NT * in_rows << " threads in " << in_rows - << " blocks and " << shared_mem_size << " bytes of shared memory for " << in_cols << " cols processed by " - << NT << " threads per row, adding " << temp_buf_size / 1024 << " kb of temp buffer in global memory." << std::endl; -#endif - CHECK_CUDA(op->program.kernel(name) - .instantiate(type_of(value_type), std::max(1u, num_sides), op->num_temp_vectors, tmp_len) - .configure(grid, block, shared_mem_size) - .launch(d_in, d_sides, d_out, d_scalars, d_temp, grix)); - - if(op->num_temp_vectors>0) - CHECK_CUDART(cudaFree(d_temp)); - } + CUfunction getReductionKernel(const std::pair<SpoofOperator::AggType, SpoofOperator::AggOp>& key); }; #endif // SPOOFCUDACONTEXT_H diff --git a/src/main/cuda/spoof-launcher/SpoofCellwise.h b/src/main/cuda/spoof-launcher/SpoofCellwise.h new file mode 100644 index 0000000..f1735eb --- /dev/null +++ b/src/main/cuda/spoof-launcher/SpoofCellwise.h @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once +#ifndef SYSTEMDS_SPOOFCELLWISE_H +#define SYSTEMDS_SPOOFCELLWISE_H + +#include "SpoofCUDAContext.h" +#include <algorithm> + +template<typename T> +struct SpoofCellwiseFullAgg { + + static void exec(SpoofCellwiseOp* op, uint32_t NT, uint32_t N, const std::string& op_name, + std::vector<Matrix<T>>& sides, uint32_t grix, DevMatPtrs<T>& dp) { + T value_type; + + // num ctas + uint32_t NB = std::ceil((N + NT * 2 - 1) / (NT * 2)); + dim3 grid(NB, 1, 1); + dim3 block(NT, 1, 1); + uint32_t shared_mem_size = NT * sizeof(T); +#ifdef _DEBUG + // ToDo: connect output to SystemDS logging facilities + std::cout << "launching spoof cellwise kernel " << op_name << " with " + << NT * NB << " threads in " << NB << " blocks and " + << shared_mem_size + << " bytes of shared memory for full aggregation of " + << N << " elements" + << std::endl; +#endif + CHECK_CUDA(op->program.get()->kernel(op_name) + .instantiate(type_of(value_type), std::max(1ul, sides.size())) + .configure(grid, block, shared_mem_size) + .launch(dp.in, dp.sides, dp.out, dp.scalars, N, grix)); + + if(NB > 1) { + N = NB; + while (NB > 1) { + void* args[3] = { &dp.out, &dp.out, &N}; + + NB = std::ceil((N + NT * 2 - 1) / (NT * 2)); +#ifdef _DEBUG + std::cout << " launching spoof cellwise kernel " << op_name << " with " + << NT * NB << " threads in " << NB << " blocks and " + << shared_mem_size + << " bytes of shared memory for full aggregation of " + << N << " elements" + << std::endl; +#endif + CHECK_CUDA(cuLaunchKernel(op->agg_kernel,NB, 1, 1, NT, 1, 1, shared_mem_size, nullptr, args, nullptr)); + N = NB; + } + } + } +}; + + +template<typename T> +struct SpoofCellwiseRowAgg { + static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const std::string &op_name, + std::vector<Matrix<T>> &input, std::vector<Matrix<T>> &sides, uint32_t grix, DevMatPtrs<T>& dp) { + T value_type; + + // num ctas + uint32_t NB = input.front().rows; + dim3 grid(NB, 1, 1); + dim3 block(NT, 1, 1); + uint32_t shared_mem_size = NT * sizeof(T); +#ifdef _DEBUG + std::cout << " launching spoof cellwise kernel " << op_name << " with " + << NT * NB << " threads in " << NB << " blocks and " + << shared_mem_size << " bytes of shared memory for row aggregation of " + << N << " elements" << std::endl; +#endif + CHECK_CUDA(op->program->kernel(op_name) + .instantiate(type_of(value_type), std::max(1ul, sides.size())) + .configure(grid, block, shared_mem_size) + .launch(dp.in, dp.sides, dp.out, dp.scalars, N, grix)); + + } +}; + + +template<typename T> +struct SpoofCellwiseColAgg { + static void exec(SpoofOperator* op, uint32_t NT, uint32_t N, const std::string& op_name, + std::vector<Matrix<T>>& sides, uint32_t grix, DevMatPtrs<T>& dp) { + T value_type; + + // num ctas + uint32_t NB = std::ceil((N + NT - 1) / NT); + + dim3 grid(NB,1, 1); + dim3 block(NT,1, 1); + uint32_t shared_mem_size = 0; +#ifdef _DEBUG + std::cout << " launching spoof cellwise kernel " << op_name << " with " + << NT * NB << " threads in " << NB << " blocks for column aggregation of " + << N << " elements" << std::endl; +#endif + CHECK_CUDA(op->program->kernel(op_name) + .instantiate(type_of(value_type), std::max(1ul, sides.size())) + .configure(grid, block, shared_mem_size) + .launch(dp.in, dp.sides, dp.out, dp.scalars, N, grix)); + + } +}; + + +template<typename T> +struct SpoofCellwiseNoAgg { + static void exec(SpoofOperator *op, uint32_t NT, uint32_t N, const std::string 
&op_name, + std::vector<Matrix<T>> &input, std::vector<Matrix<T>> &sides, uint32_t grix, DevMatPtrs<T>& dp) { + T value_type; + bool sparse_input = input.front().row_ptr != nullptr; + + // num ctas + // ToDo: adaptive VT + const uint32_t VT = 4; + uint32_t NB = std::ceil((N + NT * VT - 1) / (NT * VT)); + if(sparse_input) + NB = input.front().rows; + dim3 grid(NB, 1, 1); + dim3 block(NT, 1, 1); + uint32_t shared_mem_size = 0; + +#ifdef _DEBUG + if(sparse_input) { + std::cout << "launching sparse spoof cellwise kernel " << op_name << " with " << NT * NB + << " threads in " << NB << " blocks without aggregation for " << N << " elements" + << std::endl; + } + else { + std::cout << "launching spoof cellwise kernel " << op_name << " with " << NT * NB + << " threads in " << NB << " blocks without aggregation for " << N << " elements" + << std::endl; + } +#endif + + CHECK_CUDA(op->program->kernel(op_name) + .instantiate(type_of(value_type), std::max(1ul, sides.size())) + .configure(grid, block, shared_mem_size) + .launch(dp.in, dp.sides, dp.out, dp.scalars, N, grix)); + } +}; + +template<typename T> +struct SpoofCellwise { + static void exec(SpoofCUDAContext* ctx, SpoofOperator* _op, std::vector<Matrix<T>>& input, + std::vector<Matrix<T>>& sides, Matrix<T>& output, uint32_t grix, + DevMatPtrs<T>& dp) { + + T value_type; + auto* op = dynamic_cast<SpoofCellwiseOp*>(_op); + bool sparse_input = input.front().row_ptr != nullptr; + uint32_t NT = 256; // ToDo: num threads + uint32_t N = input.front().rows * input.front().cols; + std::string op_name(op->name + "_DENSE"); + if(sparse_input) { + op_name = std::string(op->name + "_SPARSE"); + if(op->isSparseSafe() && input.front().nnz > 0) + N = input.front().nnz; + } + + switch(op->agg_type) { + case SpoofOperator::AggType::FULL_AGG: + op->agg_kernel = ctx->template getReductionKernel<T>(std::make_pair(op->agg_type, op->agg_op)); + SpoofCellwiseFullAgg<T>::exec(op, NT, N, op_name, sides, grix, dp); + break; + case SpoofOperator::AggType::ROW_AGG: + SpoofCellwiseRowAgg<T>::exec(op, NT, N, op_name, input, sides, grix, dp); + break; + case SpoofOperator::AggType::COL_AGG: + SpoofCellwiseColAgg<T>::exec(op, NT, N, op_name, sides, grix, dp); + break; + case SpoofOperator::AggType::NO_AGG: + SpoofCellwiseNoAgg<T>::exec(op, NT, N, op_name, input, sides, grix, dp); + break; + default: + throw std::runtime_error("unknown cellwise agg type" + std::to_string(static_cast<int>(op->agg_type))); + } + } +}; + + +#endif //SYSTEMDS_SPOOFCELLWISE_H diff --git a/src/main/cuda/spoof-launcher/SpoofOperator.h b/src/main/cuda/spoof-launcher/SpoofOperator.h new file mode 100644 index 0000000..f9fc5ee --- /dev/null +++ b/src/main/cuda/spoof-launcher/SpoofOperator.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#pragma once +#ifndef SYSTEMDS_SPOOFOPERATOR_H +#define SYSTEMDS_SPOOFOPERATOR_H + +#include <cmath> +#include <cstdint> +#include <string> +#include <jitify.hpp> +#include "host_utils.h" +#include "Matrix.h" + +struct SpoofOperator { +// enum class OpType : int { CW, RA, MA, OP, NONE }; + enum class AggType : int { NO_AGG, FULL_AGG, ROW_AGG, COL_AGG }; + enum class AggOp : int { SUM, SUM_SQ, MIN, MAX }; + enum class RowType : int { FULL_AGG = 4 }; + +// OpType op_type; + std::string name; +// jitify::Program program; + std::unique_ptr<jitify::Program> program; + + [[nodiscard]] virtual bool isSparseSafe() const = 0; +}; + +struct SpoofCellwiseOp : public SpoofOperator { + bool sparse_safe; + AggType agg_type; + AggOp agg_op; + CUfunction agg_kernel{}; + SpoofCellwiseOp(AggType at, AggOp ao, bool ss) : agg_type(at), agg_op(ao), sparse_safe(ss) {} + + [[nodiscard]] bool isSparseSafe() const override { return sparse_safe; } +}; + +struct SpoofRowwiseOp : public SpoofOperator { + bool TB1 = false; + uint32_t num_temp_vectors; + int32_t const_dim2; + RowType row_type; + + SpoofRowwiseOp(RowType rt, bool tb1, uint32_t ntv, int32_t cd2) : row_type(rt), TB1(tb1), num_temp_vectors(ntv), + const_dim2(cd2) {} + + [[nodiscard]] bool isSparseSafe() const override { return false; } +}; + +template<typename T> +struct DevMatPtrs { + Matrix<T>* ptrs[3] = {0,0,0}; + + Matrix<T>*& in = ptrs[0]; + Matrix<T>*& sides = ptrs[1]; + Matrix<T>*& out = ptrs[2]; + T* scalars{}; + + ~DevMatPtrs() { +#ifdef _DEBUG + std::cout << "~DevMatPtrs() before cudaFree:\n"; + int i = 0; + for (auto& p : ptrs) { + std::cout << " p[" << i << "]=" << p; + i++; + } + std::cout << std::endl; +#endif + for (auto& p : ptrs) { + if (p) { + CHECK_CUDART(cudaFree(p)); + p = nullptr; + } + } +#ifdef _DEBUG + std::cout << "~DevMatPtrs() after cudaFree:\n"; + i = 0; + for (auto& p : ptrs) { + std::cout << " p[" << i << "]=" << p; + i++; + } + std::cout << std::endl; +#endif + } +}; + +#endif //SYSTEMDS_SPOOFOPERATOR_H diff --git a/src/main/cuda/spoof-launcher/SpoofRowwise.h b/src/main/cuda/spoof-launcher/SpoofRowwise.h new file mode 100644 index 0000000..fb919b7 --- /dev/null +++ b/src/main/cuda/spoof-launcher/SpoofRowwise.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#pragma once +#ifndef SYSTEMDS_SPOOFROWWISE_H +#define SYSTEMDS_SPOOFROWWISE_H + +#include "SpoofCUDAContext.h" +#include <algorithm> + +template <typename T> +struct SpoofRowwise { + + static void exec([[maybe_unused]] SpoofCUDAContext* ctx, SpoofOperator* _op, std::vector<Matrix<T>>& input, + std::vector<Matrix<T>>& sides, Matrix<T>& output, uint32_t grix, DevMatPtrs<T>& dp) { + uint32_t NT=256; + T value_type; + bool sparse_input = input.front().row_ptr != nullptr; + auto* op = dynamic_cast<SpoofRowwiseOp*>(_op); + dim3 grid(input.front().rows, 1, 1); + dim3 block(NT, 1, 1); + unsigned int shared_mem_size = NT * sizeof(T); + + uint32_t tmp_len = 0; + uint32_t temp_buf_size = 0; + T* d_temp = nullptr; + if(op->num_temp_vectors > 0) { + tmp_len = std::max(input.front().cols, op->const_dim2 < 0 ? 0 : static_cast<uint32_t>(op->const_dim2)); + temp_buf_size = op->num_temp_vectors * tmp_len * input.front().rows * sizeof(T); +#ifdef _DEBUG + std::cout << "num_temp_vect: " << op->num_temp_vectors << " temp_buf_size: " << temp_buf_size << " tmp_len: " << tmp_len << std::endl; +#endif + CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size)); + CHECK_CUDART(cudaMemset(d_temp, 0, temp_buf_size)); + } + + std::string op_name(op->name + "_DENSE"); + if(sparse_input) + op_name = std::string(op->name + "_SPARSE"); + +#ifdef _DEBUG + // ToDo: connect output to SystemDS logging facilities + std::cout << "launching spoof rowwise kernel " << op_name << " with " << NT * input.front().rows << " threads in " + << input.front().rows << " blocks and " << shared_mem_size << " bytes of shared memory for " + << input.front().rows << " cols processed by " << NT << " threads per row, adding " + << temp_buf_size / 1024 << " kb of temp buffer in global memory." 
<< std::endl; +#endif + CHECK_CUDA(op->program->kernel(op_name) + .instantiate(type_of(value_type), std::max(1ul, sides.size()), op->num_temp_vectors, tmp_len) + .configure(grid, block, shared_mem_size) + .launch(dp.in, dp.sides, dp.out, dp.scalars, d_temp, grix)); + + if(op->num_temp_vectors > 0) + CHECK_CUDART(cudaFree(d_temp)); + +// if (op->TB1) +// CHECK_CUDART(cudaFree(b1_transposed)); + } +}; + +#endif //SYSTEMDS_SPOOFROWWISE_H diff --git a/src/main/cuda/spoof-launcher/host_utils.h b/src/main/cuda/spoof-launcher/host_utils.h index e36e764..fe8c3b5 100644 --- a/src/main/cuda/spoof-launcher/host_utils.h +++ b/src/main/cuda/spoof-launcher/host_utils.h @@ -29,7 +29,7 @@ if (status != CUDA_SUCCESS) { \ const char* str; \ cuGetErrorName(status, &str); \ - std::cout << "(CUDA) returned " << str; \ + std::cout << "(CUDA) returned: " << str; \ std::cout << " (" << __FILE__ << ":" << __LINE__ << ":" << __func__ \ << "())" << std::endl; \ } \ @@ -39,7 +39,7 @@ do { \ cudaError_t status = call; \ if (status != cudaSuccess) { \ - std::cout << "(CUDART) returned " << cudaGetErrorString(status); \ + std::cout << "(CUDART) returned: " << cudaGetErrorString(status); \ std::cout << " (" << __FILE__ << ":" << __LINE__ << ":" << __func__ \ << "())" << std::endl; \ } \ diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp b/src/main/cuda/spoof-launcher/jni_bridge.cpp index fd95d1b..9942b41 100644 --- a/src/main/cuda/spoof-launcher/jni_bridge.cpp +++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp @@ -19,7 +19,8 @@ #include "jni_bridge.h" #include "SpoofCUDAContext.h" -#include "Matrix.h" +#include "SpoofCellwise.h" +#include "SpoofRowwise.h" // JNI Methods to get/release arrays #define GET_ARRAY(env, input)((void *)env->GetPrimitiveArrayCritical(input, nullptr)) @@ -27,205 +28,169 @@ #define RELEASE_ARRAY(env, java, cpp)(env->ReleasePrimitiveArrayCritical(java, cpp, 0)) // error output helper -void printStdException(JNIEnv *env, jstring name, const std::exception& e, bool compile = false) { +void printException(const std::string& name, const std::exception& e, bool compile = false) { std::string type = compile ? "compiling" : "executing"; - if(name != nullptr) { - const char *cstr_name = env->GetStringUTFChars(name, nullptr); - std::cout << "std::exception while " << type << " SPOOF CUDA operator " << cstr_name << ":\n" << e.what() << - std::endl; - env->ReleaseStringUTFChars(name, cstr_name); - } - else - std::cout << "std::exception while " << type << " SPOOF CUDA operator (name=nullptr):\n" << e.what() << - std::endl; -} - -void printException(JNIEnv* env, jstring name, bool compile = false) { - std::string type = compile ? 
"compiling" : "executing"; - if(name != nullptr) { - const char *cstr_name = env->GetStringUTFChars(name, nullptr); - std::cout << "Unknown exception occurred while " << type << " SPOOF CUDA operator " << cstr_name << std::endl; - env->ReleaseStringUTFChars(name, cstr_name); - } - else - std::cout << "Unknown exception occurred while " << type << " SPOOF CUDA operator (name=nullptr)" << std::endl; + std::cout << "std::exception while " << type << " SPOOF CUDA operator " << name << ":\n" << e.what() << std::endl; } -void printSource(JNIEnv* env, jstring name, jstring src) { - if(src != nullptr) { - const char *cstr_src = env->GetStringUTFChars(src, nullptr); - std::cout << "Source code:\n" << cstr_src << std::endl; - env->ReleaseStringUTFChars(name, cstr_src); - } -} -JNIEXPORT jlong JNICALL +// a pod struct to have names for the passed pointers +template<typename T> +struct LaunchMetadata { + const T& opID; + const T& grix; + const T& num_inputs; + const T& num_sides; + + // num entries describing one matrix (6 entries): + // {nnz,rows,cols,row_ptr,col_idxs,data} + const T& entry_size; + const T& num_scalars; + + explicit LaunchMetadata(const size_t* jvals) : opID(jvals[0]), grix(jvals[1]), num_inputs(jvals[2]), + num_sides(jvals[3]), entry_size(jvals[4]), num_scalars(jvals[5]) {} +}; + + +[[maybe_unused]] JNIEXPORT jlong JNICALL Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context( - JNIEnv *env, jobject jobj, jint device_id, jstring resource_path) { - const char *cstr_rp = env->GetStringUTFChars(resource_path, nullptr); - size_t ctx = SpoofCUDAContext::initialize_cuda(device_id, cstr_rp); - env->ReleaseStringUTFChars(resource_path, cstr_rp); - return ctx; + JNIEnv *jenv, [[maybe_unused]] jobject jobj, jint device_id, jstring resource_path) { + const char *cstr_rp = jenv->GetStringUTFChars(resource_path, nullptr); + size_t ctx = SpoofCUDAContext::initialize_cuda(device_id, cstr_rp); + jenv->ReleaseStringUTFChars(resource_path, cstr_rp); + return ctx; } -JNIEXPORT void JNICALL + +[[maybe_unused]] JNIEXPORT void JNICALL Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context( - JNIEnv *env, jobject jobj, jlong ctx, jint device_id) { - SpoofCUDAContext::destroy_cuda(reinterpret_cast<SpoofCUDAContext *>(ctx), device_id); + [[maybe_unused]] JNIEnv *jenv, [[maybe_unused]] jobject jobj, jlong ctx, jint device_id) { + SpoofCUDAContext::destroy_cuda(reinterpret_cast<SpoofCUDAContext *>(ctx), device_id); } -JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc - (JNIEnv *env, jobject jobj, jlong ctx, jstring name, jstring src, jint type, jint agg_op, jboolean sparseSafe) { + +template<typename TEMPLATE> +int compile_spoof_operator(JNIEnv *jenv, [[maybe_unused]] jobject jobj, jlong _ctx, jstring name, jstring src, TEMPLATE op) { + std::string operator_name; try { - auto *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx); - const char *cstr_name = env->GetStringUTFChars(name, nullptr); - const char *cstr_src = env->GetStringUTFChars(src, nullptr); - - int result = ctx_->compile(cstr_src, cstr_name, SpoofOperator::OpType::CW, - SpoofOperator::AggType(type), SpoofOperator::AggOp(agg_op), SpoofOperator::RowType::NONE, sparseSafe); + auto *ctx = reinterpret_cast<SpoofCUDAContext *>(_ctx); + const char *cstr_name = jenv->GetStringUTFChars(name, nullptr); + const char *cstr_src = jenv->GetStringUTFChars(src, nullptr); + operator_name = cstr_name; + + op->name = operator_name; + int status = ctx->compile(std::move(op), cstr_src); - 
env->ReleaseStringUTFChars(src, cstr_src); - env->ReleaseStringUTFChars(name, cstr_name); - return result; + jenv->ReleaseStringUTFChars(src, cstr_src); + jenv->ReleaseStringUTFChars(name, cstr_name); + return status; } catch (std::exception& e) { - printStdException(env, name, e, true); + printException(operator_name, e, true); } catch (...) { - printException(env, name, true); + printException(operator_name, std::runtime_error("unknown exception"), true); } - printSource(env, name, src); return -1; } -/* - * Class: org_apache_sysds_hops_codegen_cplan_CNodeRow - * Method: compile_nvrtc - * Signature: (JLjava/lang/String;Ljava/lang/String;IIIZ)I - */ -JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc - (JNIEnv *env, jobject jobj, jlong ctx, jstring name, jstring src, jint type, jint const_dim2, jint num_vectors, - jboolean TB1) { + +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc + (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring src, jint type, jint agg_op, + jboolean sparseSafe) { + + std::unique_ptr<SpoofCellwiseOp> op = std::make_unique<SpoofCellwiseOp>(SpoofOperator::AggType(type), + SpoofOperator::AggOp(agg_op), sparseSafe); + + return compile_spoof_operator<std::unique_ptr<SpoofCellwiseOp>>(jenv, jobj, ctx, name, src, std::move(op)); +} + + +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc + (JNIEnv *jenv, jobject jobj, jlong ctx, jstring name, jstring src, jint type, jint const_dim2, + jint num_vectors, jboolean TB1) { + + std::unique_ptr<SpoofRowwiseOp> op = std::make_unique<SpoofRowwiseOp>(SpoofOperator::RowType(type), TB1, + num_vectors, const_dim2); + return compile_spoof_operator<std::unique_ptr<SpoofRowwiseOp>>(jenv, jobj, ctx, name, src, std::move(op)); +} + + +template<typename T, typename TEMPLATE> +int launch_spoof_operator(JNIEnv *jenv, [[maybe_unused]] jclass jobj, jlong _ctx, jlongArray _meta, jlongArray in, + jlongArray _sides, jlongArray out, jlong _scalars) { + std::string operator_name("unknown"); try { - auto *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx); - const char *cstr_name = env->GetStringUTFChars(name, nullptr); - const char *cstr_src = env->GetStringUTFChars(src, nullptr); + // retrieve data handles from JVM + auto *metacast = reinterpret_cast<size_t *>(GET_ARRAY(jenv, _meta)); + auto *ctx = reinterpret_cast<SpoofCUDAContext *>(_ctx); + auto *inputs = reinterpret_cast<size_t *>(GET_ARRAY(jenv, in)); + auto *sides = reinterpret_cast<size_t *>(GET_ARRAY(jenv, _sides)); + auto *output = reinterpret_cast<size_t *>(GET_ARRAY(jenv, out)); +// auto *scalars = reinterpret_cast<T *>(GET_ARRAY(jenv, _scalars)); + auto *scalars = reinterpret_cast<T *>(_scalars); + LaunchMetadata<size_t> meta(metacast); + + // this implicitly checks if op exists + operator_name = ctx->getOperatorName(meta.opID); + + // wrap/cast inputs + std::vector<Matrix<T>> mats_in; + for(auto i = 0; i < meta.num_inputs; i+=meta.entry_size) + mats_in.emplace_back(&inputs[i]); + + // wrap/cast sides + std::vector<Matrix<T>> mats_sides; + for(auto i = 0; i < meta.num_sides; i+=meta.entry_size) + mats_sides.emplace_back(&sides[i]); + + // wrap/cast output + Matrix<T> mat_out(output); - int result = ctx_->compile(cstr_src, cstr_name, SpoofOperator::OpType::RA, - SpoofOperator::AggType::NONE, SpoofOperator::AggOp::NONE, - SpoofOperator::RowType(type), true, const_dim2, num_vectors, TB1); + // wrap/cast scalars +// 
std::unique_ptr<Matrix<T>> mat_scalars = scalars == nullptr ? 0 : std::make_unique<Matrix<T>>(scalars); - env->ReleaseStringUTFChars(src, cstr_src); - env->ReleaseStringUTFChars(name, cstr_name); - return result; + // transfers resource pointers to GPU and calls op->exec() + ctx->launch<T, TEMPLATE>(meta.opID, mats_in, mats_sides, mat_out, scalars, meta.grix); + + // release data handles from JVM + RELEASE_ARRAY(jenv, _meta, metacast); + RELEASE_ARRAY(jenv, in, inputs); + RELEASE_ARRAY(jenv, _sides, sides); + RELEASE_ARRAY(jenv, out, output); +// RELEASE_ARRAY(jenv, _scalars, scalars); + + return 0; } catch (std::exception& e) { - printStdException(env, name, e); + printException(operator_name, e); } catch (...) { - printException(env, name); + printException(operator_name, std::runtime_error("unknown exception")); } - printSource(env, name, src); return -1; } +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f + (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, jlongArray in, jlongArray sides, jlongArray out, + jlong scalars) { + return launch_spoof_operator<float, SpoofCellwise<float>>(jenv, jobj, ctx, meta, in, sides, out, scalars); +} -JNIEXPORT jdouble JNICALL -Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d( - JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs, jint input_offset, jlongArray side_ptrs, - jlongArray out_ptrs, jdoubleArray scalars_, jlong grix, jobject inputs_, jobject out_obj) { - double result = 0.0; - try { - auto *ctx_ = reinterpret_cast<SpoofCUDAContext*>(ctx); - const char *cstr_name = env->GetStringUTFChars(name, nullptr); - - auto* inputs = reinterpret_cast<size_t*>(GET_ARRAY(env, in_ptrs)); - auto* sides = reinterpret_cast<size_t*>(GET_ARRAY(env, side_ptrs)); - auto *output = reinterpret_cast<size_t*>(GET_ARRAY(env, out_ptrs)); - auto *scalars = reinterpret_cast<double*>(GET_ARRAY(env, scalars_)); - - //ToDo: call once while init - jclass CacheableData = env->FindClass("org/apache/sysds/runtime/controlprogram/caching/CacheableData"); - if (!CacheableData) { - std::cerr << " JNIEnv -> FindClass(CacheableData) failed" << std::endl; - return -1.0; - } - jclass ArrayList = env->FindClass("java/util/ArrayList"); - if (!ArrayList) { - std::cerr << " JNIEnv -> FindClass(ArrayList) failed" << std::endl; - return -1.0; - } - jmethodID mat_obj_num_rows = env->GetMethodID(CacheableData, "getNumRows", "()J"); - if (!mat_obj_num_rows) { - std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl; - return -1.0; - } - jmethodID mat_obj_num_cols = env->GetMethodID(CacheableData, "getNumColumns", "()J"); - if (!mat_obj_num_cols) { - std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl; - return -1.0; - } - jmethodID ArrayList_size = env->GetMethodID(ArrayList, "size", "()I"); - jmethodID ArrayList_get = env->GetMethodID(ArrayList, "get", "(I)Ljava/lang/Object;"); - - std::vector<Matrix<double>> in; - jint num_inputs = env->CallIntMethod(inputs_, ArrayList_size); -#ifdef _DEBUG - std::cout << "num inputs: " << num_inputs << " offsets: " << input_offset << std::endl; -#endif - - for(auto ptr_idx = 0, input_idx = 0; input_idx < input_offset; ptr_idx+=4, input_idx++) { - jobject input_obj = env->CallObjectMethod(inputs_, ArrayList_get, input_idx); - auto m = static_cast<uint32_t>(env->CallIntMethod(input_obj, mat_obj_num_rows)); - auto n = static_cast<uint32_t>(env->CallIntMethod(input_obj, mat_obj_num_cols)); - - in.push_back(Matrix<double>{reinterpret_cast<double 
*>(inputs[ptr_idx + 3]), - reinterpret_cast<uint32_t *>(inputs[ptr_idx + 1]), - reinterpret_cast<uint32_t *>(inputs[ptr_idx + 2]), - m, n, static_cast<uint32_t>(inputs[ptr_idx])}); -#ifdef _DEBUG - std::cout << "input #" << input_idx << " m=" << m << " n=" << n << std::endl; -#endif - } - - std::vector<Matrix<double>> side_inputs; - for(uint32_t ptr_idx = 0, input_idx = input_offset; input_idx < num_inputs; ptr_idx+=4, input_idx++) { - jobject side_input_obj = env->CallObjectMethod(inputs_, ArrayList_get, input_idx); - auto m = static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_rows)); - auto n = static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_cols)); - - side_inputs.push_back(Matrix<double>{reinterpret_cast<double *>(sides[ptr_idx + 3]), - reinterpret_cast<uint32_t *>(sides[ptr_idx + 1]), - reinterpret_cast<uint32_t *>(sides[ptr_idx + 2]), - m, n, static_cast<uint32_t>(sides[ptr_idx])}); - } - - std::unique_ptr<Matrix<double>> out; - if(out_obj != nullptr) { - out = std::make_unique<Matrix<double>>(Matrix<double>{reinterpret_cast<double*>(output[3]), - reinterpret_cast<uint32_t*>(output[1]), - reinterpret_cast<uint32_t*>(output[2]), - static_cast<uint32_t>(env->CallIntMethod(out_obj, mat_obj_num_rows)), - static_cast<uint32_t>(env->CallIntMethod(out_obj, mat_obj_num_cols)), - static_cast<uint32_t>(output[0])}); - } - - result = ctx_->execute_kernel(cstr_name, in, side_inputs, out.get(), scalars, - env->GetArrayLength(scalars_), grix); - - RELEASE_ARRAY(env, in_ptrs, inputs); - RELEASE_ARRAY(env, side_ptrs, sides); - RELEASE_ARRAY(env, out_ptrs, output); - RELEASE_ARRAY(env, scalars_, scalars); - - env->ReleaseStringUTFChars(name, cstr_name); - return result; - } - catch (std::exception& e) { - printStdException(env, name, e); - } - catch (...) 
{ - printException(env, name); - } - return result; +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1d + (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, jlongArray in, jlongArray sides, jlongArray out, + jlong scalars) { + return launch_spoof_operator<double, SpoofCellwise<double>>(jenv, jobj, ctx, meta, in, sides, out, scalars); +} + +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f + (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, jlongArray in, jlongArray sides, jlongArray out, + jlong scalars) { + return launch_spoof_operator<float, SpoofRowwise<float>>(jenv, jobj, ctx, meta, in, sides, out, scalars); +} + +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1d + (JNIEnv *jenv, jclass jobj, jlong ctx, jlongArray meta, jlongArray in, jlongArray sides, jlongArray out, + jlong scalars) { + return launch_spoof_operator<double, SpoofRowwise<double>>(jenv, jobj, ctx, meta, in, sides, out, scalars); } diff --git a/src/main/cuda/spoof-launcher/jni_bridge.h b/src/main/cuda/spoof-launcher/jni_bridge.h index 48c4882..1e9ef20 100644 --- a/src/main/cuda/spoof-launcher/jni_bridge.h +++ b/src/main/cuda/spoof-launcher/jni_bridge.h @@ -34,61 +34,67 @@ extern "C" { * Method: initialize_cuda_context * Signature: (I)J */ -JNIEXPORT jlong JNICALL +[[maybe_unused]] JNIEXPORT jlong JNICALL Java_org_apache_sysds_hops_codegen_SpoofCompiler_initialize_1cuda_1context( - JNIEnv *, jobject, jint, jstring); + JNIEnv *, [[maybe_unused]] jobject, jint, jstring); /* * Class: org_apache_sysds_hops_codegen_SpoofCompiler * Method: destroy_cuda_context * Signature: (JI)V */ -JNIEXPORT void JNICALL +[[maybe_unused]] JNIEXPORT void JNICALL Java_org_apache_sysds_hops_codegen_SpoofCompiler_destroy_1cuda_1context( - JNIEnv *, jobject, jlong, jint); + [[maybe_unused]] JNIEnv *, [[maybe_unused]] jobject, jlong, jint); /* * Class: org_apache_sysds_hops_codegen_cplan_CNodeCell * Method: compile_nvrtc * Signature: (JLjava/lang/String;Ljava/lang/String;IIZ)I */ -JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc - (JNIEnv *, jobject, jlong, jstring, jstring, jint, jint, jboolean); +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeCell_compile_1nvrtc + (JNIEnv *, [[maybe_unused]] jobject, jlong, jstring, jstring, jint, jint, jboolean); /* * Class: org_apache_sysds_hops_codegen_cplan_CNodeRow * Method: compile_nvrtc * Signature: (JLjava/lang/String;Ljava/lang/String;IIIZ)I */ -JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc - (JNIEnv *, jobject, jlong, jstring, jstring, jint, jint, jint, jboolean); +[[maybe_unused]] [[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_hops_codegen_cplan_CNodeRow_compile_1nvrtc + (JNIEnv *, [[maybe_unused]] jobject, jlong, jstring, jstring, jint, jint, jint, jboolean); /* - * Class: org_apache_sysds_hops_codegen_SpoofCompiler - * Method: compile_cuda_kernel - * Signature: (JLjava/lang/String;Ljava/lang/String;)Z + * Class: org_apache_sysds_runtime_codegen_SpoofCUDACellwiseOperator + * Method: execute_f + * Signature: (J[J[J[J[JJ)I */ -JNIEXPORT jboolean JNICALL -Java_org_apache_sysds_hops_codegen_SpoofCompiler_compile_1cuda_1kernel( - JNIEnv *, jobject, jlong, jstring, jstring); +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1f + (JNIEnv *, 
jclass, jlong, jlongArray, jlongArray, jlongArray, jlongArray, jlong); /* - * Class: org_apache_sysds_runtime_instructions_gpu_SpoofCUDAInstruction + * Class: org_apache_sysds_runtime_codegen_SpoofCUDACellwiseOperator * Method: execute_d - * Signature: (...)Z + * Signature: (J[J[J[J[JJ)I + */ +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDACellwise_execute_1d + (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, jlongArray, jlong); + + +/* + * Class: org_apache_sysds_runtime_codegen_SpoofCUDARowwise + * Method: execute_f + * Signature: (J[J[J[J[JJ)I */ -JNIEXPORT jdouble JNICALL -Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d( - JNIEnv *, jobject, jlong, jstring, jlongArray, jint, jlongArray, jlongArray, jdoubleArray, jlong, jobject, jobject); +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1f + (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, jlongArray, jlong); -///* -// * Class: org_apache_sysds_runtime_instructions_gpu_SpoofCUDAInstruction -// * Method: execute_f -// * Signature: (...)Z -// */ -//JNIEXPORT jfloat JNICALL -//Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1f( -// JNIEnv *, jobject, jlong, jstring, jlongArray, jlongArray, jlong, jfloatArray, jlong, jlong, jlong, jlong); +/* + * Class: org_apache_sysds_runtime_codegen_SpoofCUDARowwise + * Method: execute_d + * Signature: (J[J[J[J[JJ)I + */ +[[maybe_unused]] JNIEXPORT jint JNICALL Java_org_apache_sysds_runtime_codegen_SpoofCUDARowwise_execute_1d + (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray, jlongArray, jlong); #ifdef __cplusplus } diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index b9be432..335d60f 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -176,12 +176,12 @@ public class Types // these values need to match with their native counterparts (spoof cuda ops) public enum AggOp { - SUM(1), SUM_SQ(2), MIN(3), MAX(4), - PROD(5), SUM_PROD(6), - TRACE(7), MEAN(8), VAR(9), - MAXINDEX(10), MININDEX(11), - COUNT_DISTINCT(12), - COUNT_DISTINCT_APPROX(13); + SUM(0), SUM_SQ(1), MIN(2), MAX(3), + PROD(4), SUM_PROD(5), + TRACE(6), MEAN(7), VAR(8), + MAXINDEX(9), MININDEX(10), + COUNT_DISTINCT(11), + COUNT_DISTINCT_APPROX(12); @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java index 46a7042..28b722f 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java @@ -83,7 +83,7 @@ import org.apache.sysds.parser.WhileStatementBlock; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.codegen.CodegenUtils; -import org.apache.sysds.runtime.codegen.SpoofCUDA; +import org.apache.sysds.runtime.codegen.SpoofCUDAOperator; import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; @@ -242,6 +242,8 @@ public class SpoofCompiler { if(ctx_ptr != 0) { native_contexts.put(GeneratorAPI.CUDA, ctx_ptr); API = GeneratorAPI.CUDA; + org.apache.sysds.runtime.instructions.gpu.SpoofCUDAInstruction.resetFloatingPointPrecision(); + LOG.info("Successfully 
loaded spoof cuda library"); } else { @@ -508,14 +510,16 @@ public class SpoofCompiler { if( cla == null ) { String src_cuda = ""; String src = tmp.getValue().codegen(false, GeneratorAPI.JAVA); - cla = CodegenUtils.compileClass("codegen."+ tmp.getValue().getClassname(), src); + cla = CodegenUtils.compileClass("codegen." + tmp.getValue().getClassname(), src); if(API == GeneratorAPI.CUDA) { if(tmp.getValue().isSupported(API)) { src_cuda = tmp.getValue().codegen(false, GeneratorAPI.CUDA); int op_id = tmp.getValue().compile(API, src_cuda); - if(op_id >= 0) - CodegenUtils.putNativeOpData(new SpoofCUDA(src_cuda, tmp.getValue(), op_id)); + if(op_id >= 0) { + CodegenUtils.putCUDAOpID("codegen." + tmp.getValue().getClassname(), op_id); + CodegenUtils.putCUDASource(op_id, src_cuda); + } else { LOG.warn("CUDA compilation failed, falling back to JAVA"); tmp.getValue().setGeneratorAPI(GeneratorAPI.JAVA); @@ -549,23 +553,28 @@ public class SpoofCompiler { if( PLAN_CACHE_POLICY!=PlanCachePolicy.NONE ) planCache.putPlan(tmp.getValue(), cla); } - else if( DMLScript.STATISTICS ) { - Statistics.incrementCodegenOpCacheHits(); + else { + if( DMLScript.STATISTICS ) + Statistics.incrementCodegenOpCacheHits(); + if(CodegenUtils.getCUDAopID(cla.getName()) != null) { + tmp.getValue().setGeneratorAPI(GeneratorAPI.CUDA); + tmp.getValue().setVarName(cla.getName().split("\\.")[1]); + } } //make class available and maintain hits if(cla != null) { - if(CodegenUtils.getNativeOpData(cla.getName()) != null) { - if(tmp.getValue().getVarname() == null) { - tmp.getValue().setVarName(cla.getName()); - if(tmp.getValue().getGeneratorAPI() != CodegenUtils.getNativeOpData(cla.getName()) - .getCNodeTemplate().getGeneratorAPI()) - { - tmp.getValue().setGeneratorAPI(CodegenUtils.getNativeOpData(cla.getName()) - .getCNodeTemplate().getGeneratorAPI()); - } - } - } +// if(CodegenUtils.getNativeOpData(cla.getName()) != null) { +// if(tmp.getValue().getVarname() == null) { +// tmp.getValue().setVarName(cla.getName()); +// if(tmp.getValue().getGeneratorAPI() != CodegenUtils.getNativeOpData(cla.getName()) +// .getCNodeTemplate().getGeneratorAPI()) +// { +// tmp.getValue().setGeneratorAPI(CodegenUtils.getNativeOpData(cla.getName()) +// .getCNodeTemplate().getGeneratorAPI()); +// } +// } +// } clas.put(cplan.getKey(), new Pair<Hop[], Class<?>>(tmp.getKey(), cla)); } if( DMLScript.STATISTICS ) diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java index 9d1061f..1c67e3d 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java @@ -278,8 +278,7 @@ public class CNodeCell extends CNodeTpl public int compile(GeneratorAPI api, String src) { if(api == GeneratorAPI.CUDA) - return compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, _type.getValue - (), + return compile_nvrtc(SpoofCompiler.native_contexts.get(api), _genVar, src, _type.getValue(), _aggOp != null ? 
_aggOp.getValue() : 0, _sparseSafe); return -1; } diff --git a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java index d8dfec5..8d8c134 100644 --- a/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java +++ b/src/main/java/org/apache/sysds/runtime/codegen/CodegenUtils.java @@ -43,6 +43,7 @@ import javax.tools.ToolProvider; import org.apache.commons.io.IOUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.matrix.data.Pair; import org.codehaus.janino.SimpleCompiler; import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.codegen.SpoofCompiler; @@ -65,7 +66,8 @@ public class CodegenUtils //janino-specific map of source code transfer/recompile on-demand private static ConcurrentHashMap<String, String> _src = new ConcurrentHashMap<>(); - private static ConcurrentHashMap<String, SpoofCUDA> _native_op_data = new ConcurrentHashMap<>(); + private static ConcurrentHashMap<String, Integer> _CUDA_op_IDs = new ConcurrentHashMap<>(); + private static ConcurrentHashMap<Integer, String> _CUDA_op_src = new ConcurrentHashMap<>(); //javac-specific working directory for src/class files private static String _workingDir = null; @@ -160,14 +162,18 @@ public class CodegenUtils return ret; } - public static SpoofCUDA getNativeOpData(String name) { - return _native_op_data.get(name); + public static Integer getCUDAopID(String name) { + return _CUDA_op_IDs.get(name); } - public static void putNativeOpData(SpoofCUDA op) { - _native_op_data.put(op.getName(), op); + public static void putCUDAOpID(String name, int id) { + _CUDA_op_IDs.put(name, id); } - + + public static void putCUDASource(int id, String src) { + _CUDA_op_src.put(id, src); + } + public static SideInput createSideInput(MatrixBlock in) { SideInput ret = (in.isInSparseFormat() || !in.isAllocated()) ? new SideInput(null, in, in.getNumColumns()) : diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java deleted file mode 100644 index 88d91a7..0000000 --- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysds.runtime.codegen; - -import java.util.ArrayList; - -import org.apache.sysds.hops.codegen.SpoofCompiler; -import org.apache.sysds.hops.codegen.cplan.CNodeCell; -import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg; -import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct; -import org.apache.sysds.hops.codegen.cplan.CNodeRow; -import org.apache.sysds.hops.codegen.cplan.CNodeTpl; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.instructions.cp.ScalarObject; -import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; - -public class SpoofCUDA extends SpoofOperator { - private static final long serialVersionUID = -2161276866245388359L; - - private final CNodeTpl cnt; - public final String name; - public final String src; - public final int id; - - public SpoofCUDA(String source, CNodeTpl cnode, int _id) { - name = "codegen." + cnode.getVarname(); - cnt = cnode; - src = source; - id = _id; - } - - public String getName() { - return name; - } - - public CNodeTpl getCNodeTemplate() { - return cnt; - } - - public String getSpoofTemplateType() { - if (cnt instanceof CNodeCell) - return "CW"; - else if(cnt instanceof CNodeRow) - return "RA"; - else if(cnt instanceof CNodeMultiAgg) - return "MA"; - else if(cnt instanceof CNodeOuterProduct) - return "OP"; - else - throw new RuntimeException("unknown spoof operator type"); - } - @Override - public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) { - throw new RuntimeException("method not implemented for SpoofNativeCUDA"); - } - - public double execute(ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects, MatrixObject out_obj, - ExecutionContext ec, boolean sparseOut) { - double ret; - long[] out_ptr = {0,0,0,0}; - - if(out_obj != null) { - if(sparseOut) { - out_ptr[0] = ec.getGPUSparsePointerAddress(out_obj).nnz; - out_ptr[1] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).rowPtr); - out_ptr[2] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).colInd); - out_ptr[3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).val); - } - else - out_ptr[3] = ec.getGPUDensePointerAddress(out_obj); - } - - int offset = 1; - if(cnt instanceof CNodeOuterProduct) - offset = 2; - - // only dense input preparation for now - long[] in_ptrs = new long[offset * 4]; - for(int i = 0; i < offset; i += 4) { - if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) { - in_ptrs[i * 4] = ec.getGPUSparsePointerAddress(inputs.get(i)).nnz; - in_ptrs[i * 4 + 1] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr); - in_ptrs[i * 4 + 2] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd); - in_ptrs[i * 4 + 3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val); - } - else - in_ptrs[i * 4 + 3] = ec.getGPUDensePointerAddress(inputs.get(i)); - } - - long[] side_ptrs = new long[(inputs.size() - offset) * 4]; - for(int i = offset; i < inputs.size(); i++) { - int j = (i - offset) * 4; - if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) { - side_ptrs[j] = ec.getGPUSparsePointerAddress(inputs.get(i)).nnz; - side_ptrs[j + 1] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr); - side_ptrs[j + 2] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd); - side_ptrs[j + 3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val); - } - else - side_ptrs[j + 3] = ec.getGPUDensePointerAddress(inputs.get(i)); - } - -// if(isSinglePrecision()) { -// float[] scalars = prepInputScalarsFloat(scalarObjects); -// -// // ToDo: handle float -// ret = execute_f(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1], -// in_ptrs, side_ptrs, out_ptr, scalars, inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), out_obj.getNumColumns(),0); -// -// } -// else { - double[] scalars = prepInputScalars(scalarObjects); - - ret = execute_d(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), name.split("\\.")[1], - in_ptrs, offset, side_ptrs, out_ptr, scalars,0, inputs, out_obj); -// } - return ret; - } - - @Override - public String getSpoofType() { - String[] tmp = getClass().getName().split("\\."); - return tmp[tmp.length-1] + "_" + getSpoofTemplateType() + "_" + name.split("\\.")[1]; - } - - private native float execute_f(long ctx, String name, long[] in_ptr, long[] side_ptr, - long out_ptr, float[] scalars, long m, long n, long out_len, long grix); - - private native double execute_d(long ctx, String name, long[] in_ptr, int offset, long[] side_ptr, long[] out_ptr, - double[] scalars, long grix, ArrayList<MatrixObject> inputs, MatrixObject output); -} diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java new file mode 100644 index 0000000..6963c2d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDACellwise.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.codegen; + +import jcuda.Pointer; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA; + +import java.util.ArrayList; + +public class SpoofCUDACellwise extends SpoofCellwise implements SpoofCUDAOperator { + private static final Log LOG = LogFactory.getLog(SpoofCUDACellwise.class.getName()); + private final int ID; + private final PrecisionProxy call; + private Pointer ptr; + private final SpoofCellwise fallback_java_op; + + public SpoofCUDACellwise(CellType type, boolean sparseSafe, boolean containsSeq, AggOp aggOp, int id, + PrecisionProxy ep, SpoofCellwise fallback) { + super(type, sparseSafe, containsSeq, aggOp); + ID = id; + call = ep; + ptr = null; + fallback_java_op = fallback; + } + + @Override + public ScalarObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects) { + double[] result = new double[1]; + // ToDo: this is a temporary "solution" before perf opt + int NT=256; + long N = inputs.get(0).getNumRows() * inputs.get(0).getNumColumns(); + long num_blocks = ((N + NT * 2 - 1) / (NT * 2)); + Pointer ptr = ec.getGPUContext(0).allocate(getName(), LibMatrixCUDA.sizeOfDataType * num_blocks); + long[] out = {1,1,1, 0, 0, GPUObject.getPointerAddress(ptr)}; + int offset = 1; + if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, offset), + prepareSideInputPointers(ec, inputs, offset, false), out, scalarObjects, 0 ) != 0) { + LOG.error("SpoofCUDA " + getSpoofType() + " operator failed to execute. 
Trying Java fallback.\n");
+			// ToDo: java fallback
+		}
+
+		LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, result, getName(), false);
+
+		return new DoubleObject(result[0]);
+	}
+
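+	// Illustration only: every input/side/output matrix crosses the JNI boundary as a flat
+	// 6-entry long[] descriptor {nnz, rows, cols, rowPtr, colIdx, data}, as produced by
+	// prepareInputPointers()/prepareOutputPointers() and consumed by the native LaunchMetadata.
+	// This hypothetical helper is not called anywhere; it merely documents the dense layout,
+	// where the CSR slots 3 and 4 remain zero:
+	@SuppressWarnings("unused")
+	private static long[] denseDescriptorSketch(long nnz, long rows, long cols, long dataPtr) {
+		long[] d = new long[JNI_MAT_ENTRY_SIZE]; // == 6, see SpoofCUDAOperator
+		d[0] = nnz;     // non-zero count (rows*cols for a fully dense block)
+		d[1] = rows;    // row count
+		d[2] = cols;    // column count
+		d[5] = dataPtr; // device pointer to dense values; d[3]=rowPtr, d[4]=colIdx stay 0
+		return d;
+	}
+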
+	@Override public String getName() {
+		return getSpoofType();
+	}
+
+	@Override public void setScalarPtr(Pointer _ptr) {
+		ptr = _ptr;
+	}
+
+	@Override public Pointer getScalarPtr() {
+		return ptr;
+	}
+
+	@Override public void releaseScalarGPUMemory(ExecutionContext ec) {
+		if(ptr != null) {
+			ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr, DMLScript.EAGER_CUDA_FREE);
+			ptr = null;
+		}
+	}
+
+	@Override
+	public MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalarObjects,
+			String outputName) {
+
+		long out_rows = ec.getMatrixObject(outputName).getNumRows();
+		long out_cols = ec.getMatrixObject(outputName).getNumColumns();
+		MatrixObject a = inputs.get(0);
+		GPUContext gctx = ec.getGPUContext(0);
+		int m = (int) a.getNumRows();
+		int n = (int) a.getNumColumns();
+		double[] scalars = prepInputScalars(scalarObjects);
+		if(_type == CellType.COL_AGG)
+			out_rows = 1;
+		else if(_type == SpoofCellwise.CellType.ROW_AGG)
+			out_cols = 1;
+
+		boolean sparseSafe = isSparseSafe() || ((inputs.size() < 2) &&
+			genexec( 0, new SideInput[0], scalars, m, n, 0, 0 ) == 0);
+
+//		ec.setMetaData(outputName, out_rows, out_cols);
+		GPUObject g = a.getGPUObject(gctx);
+		boolean sparseOut = _type == CellType.NO_AGG && sparseSafe && g.isSparse();
+
+		long nnz = g.getNnz("spoofCUDA" + getSpoofType(), false);
+		if(sparseOut)
+			LOG.warn("Allocating sparse output for SpoofCUDA " + getSpoofType());
+		MatrixObject out_obj = sparseOut ?
+			(ec.getSparseMatrixOutputForGPUInstruction(outputName, out_rows, out_cols, (isSparseSafe() && nnz > 0) ?
+				nnz : out_rows * out_cols).getKey()) :
+			(ec.getDenseMatrixOutputForGPUInstruction(outputName, out_rows, out_cols).getKey());
+
+		int offset = 1;
+		if(!(inputIsEmpty(a.getGPUObject(gctx)) && sparseSafe)) {
+			if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, offset), prepareSideInputPointers(ec, inputs, offset, false),
+				prepareOutputPointers(ec, out_obj, sparseOut), scalarObjects, 0) != 0) {
+				LOG.error("SpoofCUDA " + getSpoofType() + " operator failed to execute. Trying Java fallback. (ToDo)\n");
+				// ToDo: java fallback
+			}
+		}
+		return out_obj;
+	}
+
+	private boolean inputIsEmpty(GPUObject g) {
+		// empty means neither a dense nor a sparse device buffer is allocated
+		if(g.getDensePointer() == null && g.getSparseMatrixCudaPointer() == null)
+			return true;
+		return false;
+	}
+
+	// used to determine sparse safety
+	@Override
+	protected double genexec(double a, SideInput[] b, double[] scalars, int m, int n, long gix, int rix, int cix) {
+		return fallback_java_op.genexec(a, b, scalars, m, n, 0, 0, 0);
+	}
+
+	public int execute_sp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
+		return execute_f(ctx, meta, in, sides, out, scalars);
+	}
+
+	public int execute_dp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
+		return execute_d(ctx, meta, in, sides, out, scalars);
+	}
+
+	public static native int execute_f(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars);
+	public static native int execute_d(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars);
+}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java
new file mode 100644
index 0000000..cec3b70
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDAOperator.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.sysds.runtime.codegen; + +import java.util.ArrayList; + +import jcuda.Pointer; +import org.apache.sysds.hops.codegen.SpoofCompiler; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA; +import org.json4s.ParserUtil; + +import static org.apache.sysds.runtime.matrix.data.LibMatrixCUDA.sizeOfDataType; + +public interface SpoofCUDAOperator { + int JNI_MAT_ENTRY_SIZE = 6; + abstract class PrecisionProxy { + protected final long ctx; + + public PrecisionProxy() { ctx = SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA); } + + public abstract int exec(ExecutionContext ec, SpoofCUDAOperator op, int opID, long[] in, long[] sides, long[] out, + ArrayList<ScalarObject> scalarObjects, long grix); + + protected Pointer transferScalars(ExecutionContext ec, SpoofCUDAOperator op, int sizeOfDataType, + ArrayList<ScalarObject> scalarObjects) { + double[] s = SpoofOperator.prepInputScalars(scalarObjects); + Pointer ptr = ec.getGPUContext(0).allocate(op.getName(), (long) scalarObjects.size() * sizeOfDataType); + LibMatrixCUDA.cudaSupportFunctions.hostToDevice(ec.getGPUContext(0), s, ptr, op.getName()); + return ptr; + } + } + + String getName(); + + void setScalarPtr(Pointer ptr); + + Pointer getScalarPtr(); + + void releaseScalarGPUMemory(ExecutionContext ec); + + default long [] prepareInputPointers(ExecutionContext ec, ArrayList<MatrixObject> inputs, int offset) { + long [] in = new long[offset * JNI_MAT_ENTRY_SIZE]; + for(int i = 0; i < offset; i++) { + int j = i * JNI_MAT_ENTRY_SIZE; + + if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) { + in[j] = ec.getGPUSparsePointerAddress(inputs.get(i)).nnz; + in[j + 1] = inputs.get(i).getNumRows(); + in[j + 2] = inputs.get(i).getNumColumns(); + in[j + 3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr); + in[j + 4] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd); + in[j + 5] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val); + } + else { + in[j] = inputs.get(i).getNnz(); + in[j + 1] = inputs.get(i).getNumRows(); + in[j + 2] = inputs.get(i).getNumColumns(); + in[j + 5] = ec.getGPUDensePointerAddress(inputs.get(i)); + } + } + return in; + } + + default long [] prepareSideInputPointers(ExecutionContext ec, ArrayList<MatrixObject> inputs, int offset, boolean tB1) { + long[] sides = new long[(inputs.size() - offset) * JNI_MAT_ENTRY_SIZE]; + for(int i = offset; i < inputs.size(); i++) { + int j = (i - offset) * JNI_MAT_ENTRY_SIZE; + if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) { + sides[j] = ec.getGPUSparsePointerAddress(inputs.get(i)).nnz; + sides[j + 1] = inputs.get(i).getNumRows(); + sides[j + 2] = inputs.get(i).getNumColumns(); + sides[j + 3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr); + sides[j + 4] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd); + sides[j + 5] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val); + } + else { + if(tB1 && j == 0) { + long rows = inputs.get(i).getNumRows(); + long cols = inputs.get(i).getNumColumns(); + Pointer b1 = inputs.get(i).getGPUObject(ec.getGPUContext(0)).getDensePointer(); + 
Pointer ptr = ec.getGPUContext(0).allocate(getName(), rows * cols * sizeOfDataType); + +// double[] tmp1 = new double[(int) (rows * cols)]; +// LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), b1, tmp1, getName(), false); +// +// System.out.println("Mat before transpose: rows=" + rows + " cols=" + cols + "\n"); +// for(int m = 0; m < rows; m++) { +// StringBuilder sb = new StringBuilder(); +// for(int n = 0; n < cols; n++) +// sb.append(" " + tmp1[(int) (cols * m + n)]); +// System.out.println(sb.toString()); +// } + + LibMatrixCUDA.denseTranspose(ec, ec.getGPUContext(0), getName(), + b1, ptr, rows, cols); + +// double[] tmp2 = new double[(int) (rows * cols)]; +// LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, tmp2, getName(), false); +// +// System.out.println("Mat after transpose: rows=" + cols + " cols=" + rows + "\n"); +// for(int m = 0; m < cols; m++) { +// StringBuilder sb = new StringBuilder(); +// for(int n = 0; n < rows; n++) +// sb.append(" " + tmp2[(int) (rows * m + n)]); +// System.out.println(sb.toString()); +// } + + sides[j] = inputs.get(i).getNnz(); + sides[j + 1] = cols; + sides[j + 2] = rows; + sides[j + 5] = GPUObject.getPointerAddress(ptr); + + } else { + sides[j] = inputs.get(i).getNnz(); + sides[j + 1] = inputs.get(i).getNumRows(); + sides[j + 2] = inputs.get(i).getNumColumns(); + sides[j + 5] = ec.getGPUDensePointerAddress(inputs.get(i)); + } + } + } + return sides; + } + + default long[] prepareOutputPointers(ExecutionContext ec, MatrixObject output, boolean sparseOut) { + long[] out = {0,0,0,0,0,0}; + + if(sparseOut) { + out[0] = ec.getGPUSparsePointerAddress(output).nnz; + out[1] = output.getNumRows(); + out[2] = output.getNumColumns(); + out[3] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).rowPtr); + out[4] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).colInd); + out[5] = GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(output).val); + } + else { + out[0] = output.getNnz(); + out[1] = output.getNumRows(); + out[2] = output.getNumColumns(); + out[5] = ec.getGPUDensePointerAddress(output); + } + return out; + } + + MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, + ArrayList<ScalarObject> scalarObjects, String outputName); + + ScalarObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, + ArrayList<ScalarObject> scalarObjects); + + int execute_sp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars); + int execute_dp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars); +} diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java new file mode 100644 index 0000000..2f1b537 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDARowwise.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.codegen; + +import jcuda.Pointer; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; +import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA; + +import java.util.ArrayList; + +public class SpoofCUDARowwise extends SpoofRowwise implements SpoofCUDAOperator { + private static final Log LOG = LogFactory.getLog(SpoofCUDARowwise.class.getName()); + private final int ID; + private final PrecisionProxy call; + private Pointer ptr; + + public SpoofCUDARowwise(RowType type, long constDim2, boolean tB1, int reqVectMem, int id, + PrecisionProxy ep) { + super(type, constDim2, tB1, reqVectMem); + ID = id; + call = ep; + ptr = null; + } + + @Override public String getName() { + return getSpoofType(); + } + + @Override public void setScalarPtr(Pointer _ptr) { + ptr = _ptr; + } + + @Override public Pointer getScalarPtr() { + return ptr; + } + + @Override public void releaseScalarGPUMemory(ExecutionContext ec) { + if(ptr != null) { + ec.getGPUContext(0).cudaFreeHelper(getSpoofType(), ptr, DMLScript.EAGER_CUDA_FREE); + ptr = null; + } + } + + @Override + public ScalarObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, + ArrayList<ScalarObject> scalarObjects) { + double[] result = new double[1]; + Pointer ptr = ec.getGPUContext(0).allocate(getName(), LibMatrixCUDA.sizeOfDataType); + long[] out = {1,1,1, 0, 0, GPUObject.getPointerAddress(ptr)}; + int offset = 1; + if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, offset), prepareSideInputPointers(ec, inputs, offset, _tB1), + out, scalarObjects, 0) != 0) { + LOG.error("SpoofCUDA " + getSpoofType() + " operator failed to execute. Trying Java fallback.\n"); + // ToDo: java fallback + } + LibMatrixCUDA.cudaSupportFunctions.deviceToHost(ec.getGPUContext(0), ptr, result, getName(), false); + return new DoubleObject(result[0]); + } + + @Override + public MatrixObject execute(ExecutionContext ec, ArrayList<MatrixObject> inputs, + ArrayList<ScalarObject> scalarObjects, String outputName) { + + int m = (int) inputs.get(0).getNumRows(); + int n = (int) inputs.get(0).getNumColumns(); + final int n2 = _type.isConstDim2(_constDim2) ? (int)_constDim2 : _type.isRowTypeB1() || + hasMatrixObjectSideInput(inputs) ? 
getMinColsMatrixObjectSideInputs(inputs) : -1;
+		OutputDimensions out_dims = new OutputDimensions(m, n, n2);
+		ec.setMetaData(outputName, out_dims.rows, out_dims.cols);
+		MatrixObject out_obj = ec.getDenseMatrixOutputForGPUInstruction(outputName, out_dims.rows, out_dims.cols).getKey();
+
+		int offset = 1;
+		if(call.exec(ec, this, ID, prepareInputPointers(ec, inputs, offset), prepareSideInputPointers(ec, inputs,
+			offset, _tB1), prepareOutputPointers(ec, out_obj, false), scalarObjects, 0) != 0) {
+			LOG.error("SpoofCUDA " + getSpoofType() + " operator failed to execute. Trying Java fallback.\n");
+			// ToDo: java fallback
+		}
+		return out_obj;
+	}
+
+	// unused
+	@Override protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len,
+		long grix, int rix) { }
+
+	// unused
+	@Override protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c,
+		int ci, int alen, int n, long grix, int rix) { }
+
+	public int execute_sp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
+		return execute_f(ctx, meta, in, sides, out, scalars);
+	}
+
+	public int execute_dp(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars) {
+		return execute_d(ctx, meta, in, sides, out, scalars);
+	}
+
+	public static native int execute_f(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars);
+	public static native int execute_d(long ctx, long[] meta, long[] in, long[] sides, long[] out, long scalars);
+}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
index 89b502e..e84ed3a 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java
@@ -52,10 +52,10 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl
 	// these values need to match with their native counterparts (spoof cuda ops)
 	public enum CellType {
-		NO_AGG(1),
-		FULL_AGG(2),
-		ROW_AGG(3),
-		COL_AGG(4);
+		NO_AGG(0),
+		FULL_AGG(1),
+		ROW_AGG(2),
+		COL_AGG(3);
 		private final int value;
 		private final static HashMap<Integer, CellType> map = new HashMap<>();
@@ -87,7 +87,7 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl
 		MAX,
 	}
-	private final CellType _type;
+	protected final CellType _type;
 	private final AggOp _aggOp;
 	private final boolean _sparseSafe;
 	private final boolean _containsSeq;
@@ -115,6 +115,10 @@ public abstract class SpoofCellwise extends SpoofOperator implements Serializabl
 		return _containsSeq;
 	}
+	@Override public SpoofCUDAOperator createCUDAInstruction(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
+		return new SpoofCUDACellwise(_type, _sparseSafe, _containsSeq, _aggOp, opID, ep, this);
+	}
+
 	@Override
 	public String getSpoofType() {
 		return "Cell" + getClass().getName().split("\\.")[1];
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofMultiAggregate.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofMultiAggregate.java
index 2b0eb1f..514b21e 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofMultiAggregate.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofMultiAggregate.java
@@ -68,6 +68,11 @@ public abstract class SpoofMultiAggregate extends SpoofOperator implements Seria
 		return "MA" + getClass().getName().split("\\.")[1];
 	}
+	@Override public SpoofCUDAOperator createCUDAInstruction(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
+		// ToDo: SpoofCUDAMultiAggregate
+		return null;
+	}
+
 	@Override
 	public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
 		return execute(inputs, scalarObjects, out, 1, 0);
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
index 3088e84..1ea229e 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOperator.java
@@ -146,13 +146,6 @@ public abstract class SpoofOperator implements Serializable
 		return scalars;
 	}
-	protected static float[] prepInputScalarsFloat(ArrayList<ScalarObject> scalarObjects) {
-		float[] scalars = new float[scalarObjects.size()];
-		for(int i=0; i < scalarObjects.size(); i++)
-			scalars[i] = (float)scalarObjects.get(i).getDoubleValue();
-		return scalars;
-	}
-
 	public static long getTotalInputNnz(ArrayList<MatrixBlock> inputs) {
 		return inputs.stream().mapToLong(in -> in.getNonZeros()).sum();
 	}
@@ -231,6 +224,10 @@ public abstract class SpoofOperator implements Serializable
 		return c;
 	}
+
+
+	public abstract SpoofCUDAOperator createCUDAInstruction(Integer opID, SpoofCUDAOperator.PrecisionProxy ep);
+
 	public static class SideInput {
 		public final DenseBlock ddat;
 		public final MatrixBlock mdat;
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOuterProduct.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOuterProduct.java
index 4b65ca2..6430788 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofOuterProduct.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofOuterProduct.java
@@ -63,6 +63,11 @@ public abstract class SpoofOuterProduct extends SpoofOperator
 		return _outerProductType;
 	}
+	@Override public SpoofCUDAOperator createCUDAInstruction(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
+		// ToDo: SpoofCUDAOuterProduct
+		return null;
+	}
+
 	@Override
 	public String getSpoofType() {
 		return "OP" + getClass().getName().split("\\.")[1];
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
index bd37aaf..f8983e1 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
@@ -28,6 +28,7 @@ import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
@@ -50,17 +51,17 @@ public abstract class SpoofRowwise extends SpoofOperator
 	// Thanks to https://codingexplained.com/coding/java/enum-to-integer-and-integer-to-enum
 	// these values need to match with their native counterparts (spoof cuda ops)
 	public enum RowType {
-		NO_AGG(1),       //no aggregation
-		NO_AGG_B1(2),    //no aggregation w/ matrix mult B1
-		NO_AGG_CONST(3), //no aggregation w/ expansion/contraction
-		FULL_AGG(4),     //full row/col aggregation
-		ROW_AGG(5),      //row aggregation (e.g., rowSums() or X %*% v)
-		COL_AGG (6),     //col aggregation (e.g., colSums() or t(y) %*% X)
-		COL_AGG_T(7),    //transposed col aggregation (e.g., t(X) %*% y)
-		COL_AGG_B1(8),   //col aggregation w/ matrix mult B1
-		COL_AGG_B1_T(9), //transposed col aggregation w/ matrix mult B1
-		COL_AGG_B1R(10), //col aggregation w/ matrix mult B1 to row vector
-		COL_AGG_CONST (11);//col aggregation w/ expansion/contraction
+		NO_AGG(0),       //no aggregation
+		NO_AGG_B1(1),    //no aggregation w/ matrix mult B1
+		NO_AGG_CONST(2), //no aggregation w/ expansion/contraction
+		FULL_AGG(3),     //full row/col aggregation
+		ROW_AGG(4),      //row aggregation (e.g., rowSums() or X %*% v)
+		COL_AGG (5),     //col aggregation (e.g., colSums() or t(y) %*% X)
+		COL_AGG_T(6),    //transposed col aggregation (e.g., t(X) %*% y)
+		COL_AGG_B1(7),   //col aggregation w/ matrix mult B1
+		COL_AGG_B1_T(8), //transposed col aggregation w/ matrix mult B1
+		COL_AGG_B1R(9),  //col aggregation w/ matrix mult B1 to row vector
+		COL_AGG_CONST (10);//col aggregation w/ expansion/contraction
 		private final int value;
 		private final static HashMap<Integer, RowType> map = new HashMap<>();
@@ -130,6 +131,10 @@ public abstract class SpoofRowwise extends SpoofOperator
 		return "RA" + getClass().getName().split("\\.")[1];
 	}
+	@Override public SpoofCUDAOperator createCUDAInstruction(Integer opID, SpoofCUDAOperator.PrecisionProxy ep) {
+		return new SpoofCUDARowwise(_type, _constDim2, _tB1, _reqVectMem, opID, ep);
+	}
+
 	@Override
 	public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, int k) {
 		MatrixBlock out = ( k > 1 ) ?
@@ -272,7 +277,7 @@ public abstract class SpoofRowwise extends SpoofOperator
 			.anyMatch(in -> in.getNumColumns()>1);
 	}
-	private static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> inputs) {
+	protected static int getMinColsMatrixSideInputs(ArrayList<MatrixBlock> inputs) {
 		//For B1 types, get the output number of columns as the minimum
 		//number of columns of side input matrices other than vectors.
 		return IntStream.range(1, inputs.size())
@@ -280,20 +285,45 @@ public abstract class SpoofRowwise extends SpoofOperator
 			.filter(ncol -> ncol > 1).min().orElse(1);
 	}
-	private void allocateOutputMatrix(int m, int n, int n2, MatrixBlock out) {
-		switch( _type ) {
-			case NO_AGG: out.reset(m, n, false); break;
-			case NO_AGG_B1: out.reset(m, n2, false); break;
-			case NO_AGG_CONST: out.reset(m, (int)_constDim2, false); break;
-			case FULL_AGG: out.reset(1, 1, false); break;
-			case ROW_AGG: out.reset(m, 1, false); break;
-			case COL_AGG: out.reset(1, n, false); break;
-			case COL_AGG_T: out.reset(n, 1, false); break;
-			case COL_AGG_B1: out.reset(n2, n, false); break;
-			case COL_AGG_B1_T: out.reset(n, n2, false); break;
-			case COL_AGG_B1R: out.reset(1, n2, false); break;
-			case COL_AGG_CONST: out.reset(1, (int)_constDim2, false); break;
+	public static boolean hasMatrixObjectSideInput(ArrayList<MatrixObject> inputs) {
+		return IntStream.range(1, inputs.size())
+			.mapToObj(i -> inputs.get(i))
+			.anyMatch(in -> in.getNumColumns()>1);
+	}
+
+	protected static int getMinColsMatrixObjectSideInputs(ArrayList<MatrixObject> inputs) {
+		//For B1 types, get the output number of columns as the minimum
+		//number of columns of side input matrices other than vectors.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
index 6f3a1f7..35127f4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
@@ -19,13 +19,16 @@
 
 package org.apache.sysds.runtime.instructions.gpu;
 
+import jcuda.Sizeof;
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.codegen.cplan.CNodeCell;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.codegen.CodegenUtils;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 import org.apache.sysds.runtime.codegen.SpoofOperator;
-import org.apache.sysds.runtime.codegen.SpoofCUDA;
+import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -35,23 +38,55 @@ import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.lineage.LineageTraceable;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 
 import java.util.ArrayList;
 
 public class SpoofCUDAInstruction extends GPUInstruction implements LineageTraceable {
-	private final SpoofCUDA _op;
+	private static final Log LOG = LogFactory.getLog(SpoofCUDAInstruction.class.getName());
+	
+	public static SpoofCUDAOperator.PrecisionProxy proxy = null;
+	
+	private final SpoofCUDAOperator _op;
 	private final CPOperand[] _in;
 	public final CPOperand _out;
-	
-	private SpoofCUDAInstruction(SpoofOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
+	
+	public static class SinglePrecision extends SpoofCUDAOperator.PrecisionProxy {
+		public int exec(ExecutionContext ec, SpoofCUDAOperator op, int opID, long[] in, long[] sides, long[] out,
+			ArrayList<ScalarObject> scalarObjects, long grix) {
+			op.setScalarPtr(transferScalars(ec, op, Sizeof.FLOAT, scalarObjects));
+			long[] _metadata = { opID, grix, in.length, sides.length, out.length, scalarObjects.size() };
+			return op.execute_sp(ctx, _metadata, in, sides, out, GPUObject.getPointerAddress(op.getScalarPtr()));
+		}
+	}
+	
+	public static class DoublePrecision extends SpoofCUDAOperator.PrecisionProxy {
+		public int exec(ExecutionContext ec, SpoofCUDAOperator op, int opID, long[] in, long[] sides, long[] out,
+			ArrayList<ScalarObject> scalarObjects, long grix) {
+			if(!scalarObjects.isEmpty())
+				op.setScalarPtr(transferScalars(ec, op, Sizeof.DOUBLE, scalarObjects));
+			long[] _metadata = { opID, grix, in.length, sides.length, out.length, scalarObjects.size() };
+			return op.execute_dp(ctx, _metadata, in, sides, out, GPUObject.getPointerAddress(op.getScalarPtr()));
+		}
+	}
+	
+	/**
+	 * Sets the internal precision proxy based on DMLScript.FLOATING_POINT_PRECISION
+	 */
+	public static void resetFloatingPointPrecision() {
+		if(DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase("single")) {
+			SpoofCUDAInstruction.proxy = new SinglePrecision();
+		}
+		else if(DMLScript.FLOATING_POINT_PRECISION.equalsIgnoreCase("double")) {
+			SpoofCUDAInstruction.proxy = new DoublePrecision();
+		}
+		else {
+			throw new DMLRuntimeException("Unsupported floating point precision: " + DMLScript.FLOATING_POINT_PRECISION);
+		}
+	}
+	
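The proxy classes above pin the single/double decision at startup: the configured precision string is inspected once in resetFloatingPointPrecision(), and every later instruction dispatches through the already-chosen proxy instead of re-comparing strings. A standalone sketch of this selection pattern; the names below are illustrative stand-ins, not the SystemDS classes:

public class PrecisionProxyDemo {
	abstract static class PrecisionProxy {
		abstract int exec(long[] metadata); // returns a status code, as in the patch
	}
	static final class SinglePrecision extends PrecisionProxy {
		int exec(long[] metadata) { System.out.println("float kernel"); return 0; }
	}
	static final class DoublePrecision extends PrecisionProxy {
		int exec(long[] metadata) { System.out.println("double kernel"); return 0; }
	}
	
	static PrecisionProxy proxy = null;
	
	// Called once at startup with the configured precision.
	static void resetFloatingPointPrecision(String precision) {
		if (precision.equalsIgnoreCase("single"))
			proxy = new SinglePrecision();
		else if (precision.equalsIgnoreCase("double"))
			proxy = new DoublePrecision();
		else
			throw new IllegalArgumentException("Unsupported floating point precision: " + precision);
	}
	
	public static void main(String[] args) {
		resetFloatingPointPrecision("double");
		// metadata layout mirrors the patch: opID, grix, #inputs, #sides, #outputs, #scalars
		proxy.exec(new long[]{ 7, 0, 1, 0, 1, 0 });
	}
}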
+	private SpoofCUDAInstruction(SpoofCUDAOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
 		super(null, opcode, istr);
-		
-		if(!(op instanceof SpoofCUDA))
-			throw new RuntimeException("SpoofGPUInstruction needs an operator of type SpoofNativeCUDA!");
-		
-		_op = (SpoofCUDA) op;
+		_op = op;
 		_in = in;
 		_out = out;
 		instString = istr;
@@ -59,11 +94,18 @@ public class SpoofCUDAInstruction extends GPUInstruction implements LineageTrace
 	}
 	
 	public static SpoofCUDAInstruction parseInstruction(String str) {
+		if(proxy == null)
+			throw new RuntimeException("SpoofCUDA Executor has not been initialized");
+		
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		ArrayList<CPOperand> inlist = new ArrayList<>();
-		SpoofCUDA op = CodegenUtils.getNativeOpData(parts[2]);
-		String opcode = op.getSpoofType();
+		Integer op_id = CodegenUtils.getCUDAopID(parts[2]);
+		Class<?> cla = CodegenUtils.getClass(parts[2]);
+		SpoofOperator fallback_java_op = CodegenUtils.createInstance(cla);
+		SpoofCUDAOperator op = fallback_java_op.createCUDAInstrcution(op_id, proxy);
+		String opcode = parts[0] + "CUDA" + fallback_java_op.getSpoofType();
 		for( int i=3; i<parts.length-2; i++ )
 			inlist.add(new CPOperand(parts[i]));
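The opcode assembled in parseInstruction above (parts[0] + "CUDA" + the Java operator's spoof type) is what later surfaces in the runtime statistics, which is why the test expectations below change to "gpu_spoofCUDACell" and "gpu_spoofCUDARA". A small sketch of that assembly; the prefix and the generated class name are hypothetical example values:

public class OpcodeNamingDemo {
	// Mirrors: opcode = parts[0] + "CUDA" + fallback_java_op.getSpoofType(),
	// where a rowwise template's getSpoofType() returns "RA" + <generated class name>.
	static String buildOpcode(String instPrefix, String spoofType) {
		return instPrefix + "CUDA" + spoofType;
	}
	
	public static void main(String[] args) {
		// Assuming an instruction prefix "spoof" and a generated class "TMP25":
		System.out.println(buildOpcode("spoof", "RATMP25"));   // spoofCUDARATMP25
		System.out.println(buildOpcode("spoof", "CellTMP3"));  // spoofCUDACellTMP3
	}
}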
@@ -87,51 +129,25 @@ public class SpoofCUDAInstruction extends GPUInstruction implements LineageTrace
 			}
 		}
 		
-		// set the output dimensions to the hop node matrix dimensions
-		if( _out.getDataType() == Types.DataType.MATRIX) {
-			long out_rows = ec.getMatrixObject(_out.getName()).getNumRows();
-			long out_cols = ec.getMatrixObject(_out.getName()).getNumColumns();
-			
-			if(_op.getSpoofTemplateType().contains("CW"))
-				if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.COL_AGG)
-					out_rows = 1;
-				else if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == SpoofCellwise.CellType.ROW_AGG)
-					out_cols = 1;
-			
-			if(_op.getSpoofTemplateType().contains("RA")) {
-				// ToDo: make this workaround proper!!
-				boolean isConstDim2 = false;
-				int pos = _op.src.indexOf("// ConstDim2: ");
-				String strDim2 = _op.src.substring(pos + 14, _op.src.indexOf(System.lineSeparator(), pos));
-				int dim2 = Integer.parseInt(strDim2);
-				if(dim2 > 0)
-					isConstDim2 = true;
-				
-				long n = inputs.get(0).getNumColumns();
-				long n2 = isConstDim2 ? dim2 : inputs.get(0).getNumRows();
-				if(_op.src.contains("COL_AGG_B1_T")) {
-					out_rows = n;
-					out_cols = n2;
-				}
+		try {
+			// set the output dimensions to the hop node matrix dimensions
+			if(_out.getDataType() == Types.DataType.MATRIX) {
+				_op.execute(ec, inputs, scalars, _out.getName());
+				ec.releaseMatrixOutputForGPUInstruction(_out.getName());
+			}
+			else if(_out.getDataType() == Types.DataType.SCALAR) {
+				ScalarObject out = _op.execute(ec, inputs, scalars);
+				ec.setScalarOutput(_out.getName(), out);
 			}
-			ec.setMetaData(_out.getName(), out_rows, out_cols);
-			GPUObject g = inputs.get(0).getGPUObject(ec.getGPUContext(0));
-			boolean sparseOut = g.isSparse() && _op.getSpoofTemplateType().contains("CW");
-			MatrixObject out_obj = sparseOut ?
-				(ec.getSparseMatrixOutputForGPUInstruction(_out.getName(), out_rows, out_cols, g.getNnz(getOpcode(), false)).getKey()) :
-				(ec.getDenseMatrixOutputForGPUInstruction(_out.getName(), out_rows, out_cols).getKey());
-			_op.execute(inputs, scalars, out_obj, ec, sparseOut);
-			ec.releaseMatrixOutputForGPUInstruction(_out.getName());
+			_op.releaseScalarGPUMemory(ec);
 		}
-		else if (_out.getDataType() == Types.DataType.SCALAR) {
-			ScalarObject out = new DoubleObject(_op.execute(inputs, scalars, null, ec, false));
-			ec.setScalarOutput(_out.getName(), out);
+		catch(Exception ex) {
+			LOG.error("SpoofCUDAInstruction: " + _op.getName() + " operator failed to execute; Java fallback not implemented yet (ToDo).");
+			throw new DMLRuntimeException(ex);
 		}
 		
 		for (CPOperand input : _in)
 			if(input.getDataType()== Types.DataType.MATRIX)
 				ec.releaseMatrixInputForGPUInstruction(input.getName());
@@ -139,7 +155,6 @@ public class SpoofCUDAInstruction extends GPUInstruction implements LineageTrace
 	
 	@Override
 	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
-		return Pair.of(_out.getName(),
-			new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _in)));
+		return Pair.of(_out.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _in)));
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
index f59ab95..1245502 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
@@ -1027,5 +1027,5 @@ public class GPUObject {
 	}
 	
 	public static long getPointerAddress(Pointer p) {
-		return getPointerAddressInternal(p);
+		return (p == null) ? 0 : getPointerAddressInternal(p);
 	}
 }
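The GPUObject change above makes getPointerAddress null-safe, so callers such as the precision proxies can pass an absent scalar buffer and hand address 0 to the native side. A standalone illustration of the same guard; Pointer and nativeAddress() below are hypothetical stand-ins for the jcuda types:

final class Pointer {
	private final long addr;
	Pointer(long addr) { this.addr = addr; }
	long nativeAddress() { return addr; }
}

public class PointerAddressDemo {
	static long getPointerAddress(Pointer p) {
		return (p == null) ? 0 : p.nativeAddress(); // 0 signals "no buffer" to native code
	}
	
	public static void main(String[] args) {
		System.out.println(getPointerAddress(null));            // 0
		System.out.println(getPointerAddress(new Pointer(42))); // 42
	}
}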
diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
index 10ea91e..b9a6917 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java
@@ -501,7 +501,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
 			if( !(rewrites && (testname.equals(TEST_NAME2)
 				|| testname.equals(TEST_NAME19))) && !testname.equals(TEST_NAME27) )
 				Assert.assertTrue(heavyHittersContainsSubString(
-					"spoofCell", "sp_spoofCell", "spoofMA", "sp_spoofMA", "gpu_SpoofCUDA_CW_"));
+					"spoofCell", "sp_spoofCell", "spoofMA", "sp_spoofMA", "gpu_spoofCUDACell"));
 			if( testname.equals(TEST_NAME7) ) //ensure matrix mult is fused
 				Assert.assertTrue(!heavyHittersContainsSubString("tsmm"));
 			else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused
diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
index b1ab778..fb6cae1 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
@@ -819,7 +819,7 @@ public class RowAggTmplTest extends AutomatedTestBase
 			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
 			Assert.assertTrue(heavyHittersContainsSubString("spoofRA")
 				|| heavyHittersContainsSubString("sp_spoofRA")
-				|| heavyHittersContainsSubString("gpu_SpoofCUDA_RA"));
+				|| heavyHittersContainsSubString("gpu_spoofCUDARA"));
 			
 			//ensure full aggregates for certain patterns
 			if( testname.equals(TEST_NAME15) )
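The two test updates above only swap the expected opcode substrings to the new naming scheme. For completeness, a minimal self-contained check of that expectation; the plain substring scan below stands in for the heavyHittersContainsSubString(...) helper from the SystemDS test infrastructure, and the statistics names are hypothetical examples:

public class HeavyHitterCheckDemo {
	static boolean containsSubString(String[] stats, String needle) {
		for (String s : stats)
			if (s.contains(needle))
				return true;
		return false;
	}
	
	public static void main(String[] args) {
		String[] stats = { "gpu_spoofCUDACellTMP3", "sp_spoofRA7" };
		System.out.println(containsSubString(stats, "gpu_spoofCUDACell")); // true  (new naming)
		System.out.println(containsSubString(stats, "gpu_SpoofCUDA_CW_")); // false (old naming)
	}
}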
