This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 0084f24dc02c325a6a6eb2fe5c823f97e0db5359
Author: Mark Dokter <[email protected]>
AuthorDate: Mon Feb 22 14:00:42 2021 +0100

    [SYSTEMDS-2826] Sparse input support for CUDA codegen
     * Code template handling refactor
     * A few code snippets from the row template that the diff did not cleanly
       separate are included, so this commit may not compile or run without the
       companion commit 7bc6379d59a0c19d881fdac8229be64d880d30cc.
       The intent was to split the work into smaller chunks with moderate effort.
---
 pom.xml                                            |  11 +
 src/main/cuda/CMakeLists.txt                       |  32 +-
 src/main/cuda/headers/Matrix.h                     | 299 ++++++++++++
 src/main/cuda/headers/reduction.cuh                | 172 ++++---
 src/main/cuda/headers/spoof_utils.cuh              |  40 +-
 src/main/cuda/kernels/reduction.cu                 | 306 ++++++------
 src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp  | 151 +++---
 src/main/cuda/spoof-launcher/SpoofCUDAContext.h    | 531 +++++++++++++--------
 src/main/cuda/spoof-launcher/jni_bridge.cpp        | 187 ++++++--
 src/main/cuda/spoof-launcher/jni_bridge.h          |  18 +-
 src/main/cuda/spoof/cellwise.cu                    |  64 ++-
 .../org/apache/sysds/hops/codegen/cplan/CNode.java |  87 ++--
 .../sysds/hops/codegen/cplan/CNodeBinary.java      |  28 +-
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java |  37 +-
 .../apache/sysds/hops/codegen/cplan/CNodeData.java |   3 +-
 .../apache/sysds/hops/codegen/cplan/CNodeNary.java |  17 +-
 .../sysds/hops/codegen/cplan/CNodeUnary.java       |  15 +-
 .../sysds/hops/codegen/cplan/CodeTemplate.java     |  58 ++-
 .../sysds/hops/codegen/cplan/cuda/Binary.java      |  31 +-
 .../sysds/hops/codegen/cplan/cuda/CellWise.java    |  77 ---
 .../sysds/hops/codegen/cplan/cuda/Ternary.java     |  40 +-
 .../sysds/hops/codegen/cplan/cuda/Unary.java       |  33 +-
 .../sysds/hops/codegen/cplan/java/Binary.java      |  26 +-
 .../sysds/hops/codegen/cplan/java/CellWise.java    |  79 ---
 .../Cellwise.java.template}                        |  31 +-
 .../sysds/hops/codegen/cplan/java/Ternary.java     |  27 +-
 .../sysds/hops/codegen/cplan/java/Unary.java       |  25 +-
 .../apache/sysds/runtime/codegen/SpoofCUDA.java    |  92 ++--
 .../controlprogram/context/ExecutionContext.java   |  21 +-
 .../instructions/gpu/SpoofCUDAInstruction.java     |  44 +-
 .../instructions/gpu/context/GPUObject.java        |  15 +-
 31 files changed, 1596 insertions(+), 1001 deletions(-)

diff --git a/pom.xml b/pom.xml
index 7212f14..f95eafd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -111,9 +111,20 @@
                                        <include>spoof_utils.cuh</include>
                                        <include>TempStorage.cuh</include>
                                        <include>utils.cuh</include>
+                                       <include>vector_write.cuh</include>
+                                       <include>vector_add.cuh</include>
+                                       <include>Matrix.h</include>
                                </includes>
                                <targetPath>cuda/headers</targetPath>
                        </resource>
+                       <resource>
+                               
<directory>src/main/java/org/apache/sysds/hops/codegen/cplan/java</directory>
+                               <includes>
+                                       
<include>Cellwise.java.template</include>
+                                       <include>Rowwise.java.template</include>
+                               </includes>
+                               <targetPath>java/spoof</targetPath>
+                       </resource>
                </resources>
 
                <plugins>
diff --git a/src/main/cuda/CMakeLists.txt b/src/main/cuda/CMakeLists.txt
index 8b74dee..cfa72c4 100644
--- a/src/main/cuda/CMakeLists.txt
+++ b/src/main/cuda/CMakeLists.txt
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
+cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
 
 # default to gcc 8.x while we're still supporting CUDA 10.x only
 if (UNIX)
@@ -49,8 +49,13 @@ target_include_directories(SystemDS PUBLIC "${CMAKE_SOURCE_DIR}/headers")
 
 find_package(CUDAToolkit REQUIRED)
 cmake_policy(SET CMP0104 NEW)
+
 set(CMAKE_CUDA_ARCHITECTURES  OFF)
 #ToDo: more compiler flag settings for Debug/Release compilation
+# CUDA_ARCHITECTURES OFF: CMake adds no architecture flags, so nvcc uses its default target.
+# Avoid forcing a cached architecture list here: set(... CACHE ... FORCE) would overwrite
+# whatever the user configured.
+
 set(CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr")
 
 set_property(TARGET SystemDS PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
@@ -85,15 +90,28 @@ endif()
 
 set(SPOOF_HEADERS 
        spoof-launcher/jni_bridge.h
-       spoof-launcher/SpoofCUDAContext.h)
+       spoof-launcher/SpoofCUDAContext.h headers/TempStorage.cuh)
 set(SPOOF_SOURCES 
        spoof-launcher/jni_bridge.cpp
        spoof-launcher/SpoofCUDAContext.cpp)
-
-add_library(spoof_cuda SHARED ${SPOOF_HEADERS} ${SPOOF_SOURCES} )
-
-target_include_directories(spoof_cuda PRIVATE "${CMAKE_SOURCE_DIR}/ext/jitify")
-target_link_libraries(spoof_cuda CUDA::nvrtc CUDA::cuda_driver CUDA::cudart)
+set(SPOOF_CUDA_HEADERS
+       headers/agg_ops.cuh
+       headers/reduction.cuh
+       headers/spoof_utils.cuh
+       headers/utils.cuh headers/operators.cuh headers/Matrix.h headers/vector_write.cuh headers/vector_add.cuh)
+               #headers/intellisense_cuda_intrinsics.h
+set(SPOOF_TEMPLATES
+       spoof/cellwise.cu    
+       spoof/rowwise.cu)
+source_group("SPOOF Templates" FILES ${SPOOF_TEMPLATES})
+set_source_files_properties( ${SPOOF_TEMPLATES} PROPERTIES HEADER_FILE_ONLY ON)
+add_library(spoof_cuda SHARED ${SPOOF_HEADERS} ${SPOOF_CUDA_HEADERS} ${SPOOF_SOURCES} ${SPOOF_TEMPLATES})
+
+set_property(TARGET reduction PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
+set_property(TARGET spoof_cuda PROPERTY CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES})
+
+target_include_directories(spoof_cuda PRIVATE "${CMAKE_SOURCE_DIR}/ext/jitify" headers)
+target_link_libraries(spoof_cuda CUDA::nvrtc CUDA::cuda_driver CUDA::cudart CUDA::cublas)
 target_compile_features(spoof_cuda PUBLIC cxx_std_11)
 set_target_properties(spoof_cuda PROPERTIES POSITION_INDEPENDENT_CODE ON)
 set_target_properties(spoof_cuda PROPERTIES OUTPUT_NAME "systemds_spoof_cuda-${CMAKE_SYSTEM_NAME}-${CMAKE_SYSTEM_PROCESSOR}")
diff --git a/src/main/cuda/headers/Matrix.h b/src/main/cuda/headers/Matrix.h
new file mode 100644
index 0000000..a11fc33
--- /dev/null
+++ b/src/main/cuda/headers/Matrix.h
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+#ifndef SYSTEMDS_MATRIX_H
+#define SYSTEMDS_MATRIX_H
+
+using uint32_t = unsigned int;
+
+template <typename T>
+struct Matrix { // dense if row_ptr == nullptr (see MatrixAccessor::init), CSR otherwise
+       T* data;            // values: rows*cols row-major (dense) or the stored non-zeros (CSR)
+       uint32_t* row_ptr;  // CSR row offsets: row r occupies [row_ptr[r], row_ptr[r+1]); nullptr for dense
+       uint32_t* col_idx;  // CSR column index of each stored value
+
+       uint32_t rows;
+       uint32_t cols;
+       uint32_t nnz;       // number of stored values of a sparse matrix (what len_sparse returns)
+       typedef T value_type;
+};
+
+//#ifdef __CUDACC_RTC__
+#ifdef __CUDACC__
+
+template<typename T>
+uint32_t bin_search(T* values, uint32_t lower, uint32_t upper, T val) {
+       // Searches the sorted range values[lower, upper); returns the index of val, or the
+       // insertion position if absent. Half-open bounds avoid unsigned underflow.
+       while(lower < upper) {
+               const uint32_t idx = lower + ((upper - lower) >> 1);
+               const T vi = values[idx];
+               if(vi < val)
+                       lower = idx + 1;
+               else if(val < vi)
+                       upper = idx;
+               else
+                       return idx;
+       }
+       return lower;
+}
+
+template<typename T>
+class MatrixAccessor {
+       // Device-side view of a Matrix<T>; init() selects dense (row_ptr == nullptr) or CSR accessors.
+       Matrix<T>* _mat;
+
+       // Member function pointers
+       uint32_t (MatrixAccessor::*_len)();
+       uint32_t (MatrixAccessor::*_row_len)(uint32_t);
+       uint32_t (MatrixAccessor::*_pos)(uint32_t);
+       uint32_t* (MatrixAccessor::*_col_idxs)(uint32_t);
+
+       T& (MatrixAccessor::*_val_i)(uint32_t);
+       T& (MatrixAccessor::*_val_rc)(uint32_t, uint32_t);
+       T* (MatrixAccessor::*_vals)(uint32_t);
+       void (MatrixAccessor::*_set)(uint32_t, uint32_t, T);
+
+public:
+       MatrixAccessor() = default;
+
+       __device__ MatrixAccessor(Matrix<T>* mat) { init(mat); }
+
+       // Binds the accessor to mat; both branches assign every dispatch pointer.
+       __device__ void init(Matrix<T>* mat) {
+               _mat = mat;
+
+               if (_mat->row_ptr == nullptr) {
+                       _len = &MatrixAccessor::len_dense;
+                       _pos = &MatrixAccessor::pos_dense;
+                       _col_idxs = &MatrixAccessor::cols_dense;
+                       _val_rc = &MatrixAccessor::val_dense_rc;
+                       _vals = &MatrixAccessor::vals_dense;
+                       _row_len = &MatrixAccessor::row_len_dense;
+                       _val_i = &MatrixAccessor::val_dense_i;
+                       _set = &MatrixAccessor::set_dense;
+               } else {
+                       _len = &MatrixAccessor::len_sparse;
+                       _pos = &MatrixAccessor::pos_sparse;
+                       _col_idxs = &MatrixAccessor::cols_sparse;
+                       _val_rc = &MatrixAccessor::val_sparse_rc;
+                       _vals = &MatrixAccessor::vals_sparse;
+                       _row_len = &MatrixAccessor::row_len_sparse;
+                       _val_i = &MatrixAccessor::val_sparse_i;
+                       _set = &MatrixAccessor::set_sparse;
+               }
+       }
+
+       __device__ uint32_t& nnz() { return _mat->nnz; }
+       __device__ uint32_t cols() { return _mat->cols; }
+       __device__ uint32_t rows() { return _mat->rows; }
+
+       __device__ uint32_t len() { return (this->*_len)(); } // rows*cols (dense) or nnz (sparse)
+
+       __device__ uint32_t pos(uint32_t rix) { // offset of row rix: cols*rix (dense) or row_ptr[rix] (sparse)
+               return (this->*_pos)(rix);
+       }
+
+       __device__ T& val(uint32_t r, uint32_t c) { // element (r, c); sparse not implemented yet (see val_sparse_rc)
+               return (this->*_val_rc)(r, c);
+       }
+
+       __device__ T& val(uint32_t i) { // i-th entry of the value array
+               return (this->*_val_i)(i);
+       }
+
+       // Dense: &data[rix] (rix is a linear offset). Sparse: first value of row rix.
+       __device__ T* vals(uint32_t rix) {
+               return (this->*_vals)(rix);
+       }
+
+       __device__ T& operator[](uint32_t i) {
+               return (this->*_val_i)(i);
+       }
+
+       // Entries in row rix: cols (dense) or the row's stored non-zeros (sparse).
+       __device__ uint32_t row_len(uint32_t rix) {
+               return (this->*_row_len)(rix);
+       }
+
+       __device__ uint32_t* col_idxs(uint32_t rix) { // sparse only; dense prints an error and returns nullptr
+               return (this->*_col_idxs)(rix);
+       }
+
+       // Dense: writes v at (r, c). Sparse: r is the position in data/col_idx, c the column.
+       __device__ void set(uint32_t r, uint32_t c, T v) {
+               (this->*_set)(r,c,v);
+       }
+
+       __device__ uint32_t* indexes() {
+               return _mat->row_ptr;
+       }
+private:
+       __device__ uint32_t len_dense() {
+               return _mat->rows * _mat->cols;
+       }
+
+       __device__ uint32_t pos_dense(uint32_t rix) {
+               return _mat->cols * rix;
+       }
+
+       __device__ uint32_t* cols_dense(uint32_t rix) {
+               printf("ERROR: no column indices array in a dense matrix! This will likely crash :-/\n");
+               return nullptr;
+       }
+
+       __device__ T& val_dense_rc(uint32_t r, uint32_t c) {
+               return _mat->data[_mat->cols * r + c];
+       }
+
+       __device__ T& val_dense_i(uint32_t i) {
+               return _mat->data[i];
+       }
+
+       __device__ T* vals_dense(uint32_t rix) {
+               return &(_mat->data[rix]);
+       }
+
+       // A dense row holds cols entries (this previously returned rows).
+       __device__ uint32_t row_len_dense(uint32_t rix) {
+               return _mat->cols;
+       }
+
+       __device__ void set_dense(uint32_t r, uint32_t c, T v) { // row-major, like val_dense_rc
+               _mat->data[_mat->cols * r + c] = v;
+       }
+
+       // CSR (sparse) accessors
+       __device__ uint32_t len_sparse() {
+               return _mat->nnz;
+       }
+
+       __device__ uint32_t pos_sparse(uint32_t rix) {
+               return _mat->row_ptr[rix];
+       }
+
+       __device__ uint32_t* cols_sparse(uint32_t rix) {
+               return &_mat->col_idx[_mat->row_ptr[rix]];
+       }
+
+       // ToDo: CSR random access is not implemented; returns data[0] for any (r, c).
+       __device__ T& val_sparse_rc(uint32_t r, uint32_t c) {
+               return _mat->data[0];
+       }
+
+       __device__ T& val_sparse_i(uint32_t i) {
+               return _mat->data[i];
+       }
+
+       __device__ T* vals_sparse(uint32_t rix) {
+               return &_mat->data[_mat->row_ptr[rix]];
+       }
+
+       __device__ uint32_t row_len_sparse(uint32_t rix) {
+               return _mat->row_ptr[rix+1]-_mat->row_ptr[rix];
+       }
+
+       // Writes value v and column index c at CSR position idx; row_ptr and nnz are not updated.
+       __device__ void set_sparse(uint32_t idx, uint32_t c, T v) {
+               _mat->data[idx] = v;
+               _mat->col_idx[idx] = c;
+       }
+};
+#endif
+
+
+#ifdef __CUDACC_RTC__
+
+//ToDo: move to separate file
+template <typename T>
+struct Vector {
+       T* data;          // pointer to the first element
+       uint32_t length;  // number of elements reachable through data
+
+       __device__ T* vals(uint32_t idx) { return &data[idx]; }
+
+       __device__ T& operator[](uint32_t idx) {
+               return data[idx];
+       }
+
+       // Debug print of data[start, end) from thread tID of block bID; end_ == 0 prints up to length.
+       __device__ void print(const char* name, uint32_t end_ = 0, uint32_t start = 0, uint32_t bID = 0, uint32_t tID = 0) {
+               if(blockIdx.x == bID && threadIdx.x==tID) {
+                       const uint32_t end = (end_ > 0) ? min(end_, length) : length;
+                       printf("%s: ", name);
+                       for(auto i = start; i < end; ++i)
+                               printf("%4.3f ", data[i]); // was print(...), which recursed into this member
+                       printf("\n");
+               }
+       }
+};
+
+template <typename T, uint32_t ELEMENTS>
+class RingBuffer { // ELEMENTS fixed-size Vector slots handed out round-robin by next()
+       Vector<T> vec[ELEMENTS];
+       int32_t pos; // slot returned by the last next(); -1 right after init()
+
+public:
+       __device__ void init(uint32_t offset, uint32_t length, T* buffer) { // slot i covers length elements starting at buffer[offset + length*i]
+               pos = -1;
+               for(auto i=0;i<ELEMENTS;++i) {
+                       vec[i].data = &buffer[offset + length * i];
+                       vec[i].length = length;
+               }
+       }
+
+       __device__ Vector<T>& next() { // advances to the next slot, wrapping to 0 after ELEMENTS-1
+               pos = (pos+1>=ELEMENTS) ? 0 : pos+1;
+               __syncthreads(); // block-wide barrier: must be reached by every thread of the block
+               return vec[pos];
+       }
+};
+
+template <typename T, int NUM_B>
+struct SpoofOp {
+       MatrixAccessor<T> a;            // accessor for A
+       MatrixAccessor<T> b[NUM_B];     // accessors for B[0 .. NUM_B)
+       MatrixAccessor<T> c;            // accessor for C
+       T* scalars;
+       uint32_t grix;                  // row index; alen = a.row_len(grix)
+       T* avals;                       // A->data
+       uint32_t* aix;                  // A->col_idx
+       uint32_t alen;                  // number of entries in row grix of A
+
+       // NOTE(review): tmp_stor is not used yet, and this constructor lacks __device__ although
+       // it calls the __device__ MatrixAccessor::init - confirm the NVRTC compile flags allow it.
+       SpoofOp(Matrix<T>* A, Matrix<T>* B, Matrix<T>* C, T* scalars, T* tmp_stor, uint32_t grix) :
+                       scalars(scalars), grix(grix), avals(A->data), aix(A->col_idx) {
+               a.init(A);
+               c.init(C);
+               alen = a.row_len(grix);
+
+               // B, if given, points to NUM_B consecutive Matrix descriptors
+               if(B)
+                       for(auto i = 0; i < NUM_B; ++i)
+                               b[i].init(&(B[i]));
+       }
+
+       // ToDo: temporary-storage accessor (e.g. on top of RingBuffer<T, N>::next()); the earlier
+       // commented-out draft referenced members (temp_rb, tvec) that do not exist.
+};
+#endif // __CUDACC_RTC__
+
+#endif //SYSTEMDS_MATRIX_H
diff --git a/src/main/cuda/headers/reduction.cuh 
b/src/main/cuda/headers/reduction.cuh
index 56845b5..8a024cc 100644
--- a/src/main/cuda/headers/reduction.cuh
+++ b/src/main/cuda/headers/reduction.cuh
@@ -25,6 +25,7 @@ using uint = unsigned int;
 #include <cuda_runtime.h>
 
 #include "utils.cuh"
+#include "Matrix.h"
 
 /**
  * Does a reduce operation over all elements of the array.
@@ -50,15 +51,7 @@ using uint = unsigned int;
  * @param SpoofCellwiseOp              initial value for the reduction variable
  */
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-__device__ void FULL_AGG(
-               T *g_idata, ///< input data stored in device memory (of size n)
-               T *g_odata, ///< output/temporary array stored in device memory 
(of size n)
-               uint m,
-               uint n,
-               T initialValue, 
-               ReductionOp reduction_op, 
-           SpoofCellwiseOp spoof_op)
-{
+__device__ void FULL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, 
uint32_t N, T VT, ReductionOp reduction_op, SpoofCellwiseOp spoof_op) {
        auto sdata = shared_memory_proxy<T>();
 
        // perform first level of reduction,
@@ -66,20 +59,19 @@ __device__ void FULL_AGG(
        uint tid = threadIdx.x;
        uint i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
        uint gridSize = blockDim.x * 2 * gridDim.x;
-       uint N = m * n;
-       T v = initialValue;
+       // the initial value now comes from the reduction functor (was the initialValue parameter)
+       T v = reduction_op.init();
 
        // we reduce multiple elements per thread.  The number is determined by the
        // number of active thread blocks (via gridDim).  More blocks will result
        // in a larger gridSize and therefore fewer elements per thread
        while (i < N) {
-               v = reduction_op(v, spoof_op(g_idata[i], i));
+               v = reduction_op(v, spoof_op(*(in->vals(i)), i, i / in->cols(), i % in->cols()));
 
-               if (i + blockDim.x < N) 
-               {
+               if (i + blockDim.x < N) {
                        //__syncthreads();
                        //printf("loop fetch i(%d)+blockDim.x(%d)=%d, in=%f\n",i, blockDim.x, i + blockDim.x, g_idata[i + blockDim.x]);
-                       v = reduction_op(v, spoof_op(g_idata[i + blockDim.x], blockDim.x + i));
+                       v = reduction_op(v, spoof_op(*(in->vals(i+blockDim.x)), blockDim.x + i, (i+blockDim.x) / in->cols(), (i+blockDim.x) % in->cols()));
                }
 
                i += gridSize;
@@ -140,12 +132,12 @@ __device__ void FULL_AGG(
                }
        }
 
-       // write result for this block to global mem
-       if (tid == 0) {
-               if(gridDim.x < 10)
-                       printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
-               g_odata[blockIdx.x] = sdata[0];
-       }
+        // write result for this block to global mem
+        if (tid == 0) {
+//             if(gridDim.x < 10)
+//                     printf("blockIdx.x=%d reduction result: %3.1f\n", 
blockIdx.x, sdata[0]);
+               out->val(0, blockIdx.x) = sdata[0];
+        }
 }
 
 /**
@@ -168,32 +160,35 @@ __device__ void FULL_AGG(
  * the value before writing it to its final location in global memory for each
  * row
  */
+// ROW_AGG: row-wise aggregation with one thread block per row of *in.
+//   in           input accessor; blocks with blockIdx.x >= in->rows() return immediately
+//   out          output accessor; thread 0 of each block writes out->val(row)
+//   N            appears unused here (parameter list shared with the other *_AGG functions)
+//   VT           appears unused here
+//   reduction_op combining functor; reduction_op.init() provides the start value
+//   spoof_op     cellwise op applied to each element before aggregation, called as
+//                spoof_op(value, linear index, row index, column index)
+// NOTE(review): elements are read via in->val(row * cols + col), i.e. a dense row-major
+// layout is assumed; a sparse (CSR) accessor would index its value array instead.
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-__device__ void ROW_AGG(
-               T *g_idata, ///< input data stored in device memory (of size rows*cols)
-               T *g_odata,  ///< output/temporary array store in device memory (of size
-               /// rows*cols)
-               uint rows,  ///< rows in input and temporary/output arrays
-               uint cols,  ///< columns in input and temporary/output arrays
-               T initialValue,  ///< initial value for the reduction variable
-               ReductionOp reduction_op, ///< Reduction operation to perform (functor object)
-               SpoofCellwiseOp spoof_op) ///< Operation to perform before assigning this
-{
+__device__ void ROW_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t N, T VT,  ReductionOp reduction_op,
+                                          SpoofCellwiseOp spoof_op) {
        auto sdata = shared_memory_proxy<T>();
 
        // one block per row
-       if (blockIdx.x >= rows) {
+       if (blockIdx.x >= in->rows()) {
                return;
        }
 
        uint block = blockIdx.x;
        uint tid = threadIdx.x;
-       uint i = tid;
-       uint block_offset = block * cols;
+       uint32_t i = tid;
+       uint block_offset = block * in->cols();
 
-       T v = initialValue;
-       while (i < cols) {
-               v = reduction_op(v, spoof_op(g_idata[block_offset + i], i));
+       // start from the reduction's initial value (formerly the initialValue parameter)
+       T v = reduction_op.init();
+       while (i < in->cols()) {
+               v = reduction_op(v, spoof_op(in->val(block_offset + i), block_offset + i, block, i));
                i += blockDim.x;
        }
 
@@ -254,7 +249,7 @@ __device__ void ROW_AGG(
 
        // write result for this block to global mem, modify it with assignment 
op
        if (tid == 0)
-               g_odata[block] = sdata[0];
+               out->val(block) = sdata[0];
 }
 
 /**
@@ -273,41 +268,60 @@ __device__ void ROW_AGG(
  * column
  */
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-__device__ void COL_AGG(T *g_idata, ///< input data stored in device memory (of size rows*cols)
-               T *g_odata,  ///< output/temporary array store in device memory (of size rows*cols)
-               uint rows,  ///< rows in input and temporary/output arrays
-               uint cols,  ///< columns in input and temporary/output arrays
-               T initialValue,  ///< initial value for the reduction variable
-               ReductionOp reduction_op, ///< Reduction operation to perform (functor object)
-               SpoofCellwiseOp spoof_op) ///< Operation to perform before aggregation
-               
-{
+__device__ void COL_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t N, T VT,  ReductionOp reduction_op,
+                                               SpoofCellwiseOp spoof_op) {
+       // One thread per column: thread c folds in->val(c), in->val(c + cols), ... while the index is < N.
        uint global_tid = blockIdx.x * blockDim.x + threadIdx.x;
-       if (global_tid >= cols) {
+       if (global_tid >= in->cols()) {
                return;
        }
 
        uint i = global_tid;
-       uint grid_size = cols;
-       T val = initialValue;
+       uint grid_size = in->cols();
+       T val = reduction_op.init();
 
-       while (i < rows * cols) {
-               val = reduction_op(val, spoof_op(g_idata[i], i));
+       while (i < N) {
+               val = reduction_op(val, spoof_op(in->val(i), i, i / in->cols(), i % in->cols()));
                i += grid_size;
        }
-       g_odata[global_tid] = val;
+       out->val(global_tid) = val;
 }
 
 template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
-__device__ void NO_AGG(T* g_idata, T* g_odata,  uint rows, uint cols,
-       T VT,  ReductionOp reduction_op, SpoofCellwiseOp spoof_op) 
+__device__ void NO_AGG(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t N, T VT,  ReductionOp reduction_op, SpoofCellwiseOp spoof_op)
 {
-       int tid = blockIdx.x * blockDim.x + threadIdx.x;
-       int first_idx = tid * static_cast<int>(VT);
-       int last_idx = min(first_idx + static_cast<int>(VT), spoof_op.m * spoof_op.n);
+       // Each thread handles the VT consecutive elements [gtid*VT, min(gtid*VT + VT, N)).
+       uint32_t gtid = blockIdx.x * blockDim.x + threadIdx.x;
+       uint32_t first_idx = gtid * static_cast<uint32_t>(VT);
+       uint32_t last_idx = min(first_idx + static_cast<uint32_t>(VT), N);
        #pragma unroll
-       for(int i = first_idx; i < last_idx; i++) {
-               g_odata[i] = spoof_op(g_idata[i], i);
+       for(auto i = first_idx; i < last_idx; i++) {
+               T result = spoof_op(in->vals(0)[i], i, i / in->cols(), i % in->cols());
+               out->vals(0)[i] = result;
+       }
+}
+
+template<typename T, typename ReductionOp, typename SpoofCellwiseOp>
+__device__ void NO_AGG_SPARSE(MatrixAccessor<T>* in, MatrixAccessor<T>* out, uint32_t N, T VT,  ReductionOp reduction_op, SpoofCellwiseOp spoof_op)
+{
+       // Cellwise op without aggregation on a CSR input: one thread block per row
+       // (rix = blockIdx.x) whose threads stride over that row's stored values. Each
+       // result is written to the same CSR position of *out via out->set(); out's
+       // row_ptr is not written here. N, VT and reduction_op are not used.
+       const uint32_t rix = blockIdx.x;
+       if (rix >= in->rows())
+               return; // same guard as ROW_AGG: blocks past the last row have no work
+
+       const uint32_t row_start = in->pos(rix);
+       const uint32_t row_len = in->row_len(rix);
+       const uint32_t* aix = in->col_idxs(rix);
+
+       for(uint32_t tid = threadIdx.x; tid < row_len; tid += blockDim.x) {
+               const uint32_t cix = aix[tid];
+               // row-major linear index of (rix, cix); this used rix*rows()+cix before,
+               // which is wrong whenever rows != cols
+               T result = spoof_op(in->val(row_start + tid), rix * in->cols() + cix, rix, cix);
+               out->set(row_start + tid, cix, result);
        }
 }
 
diff --git a/src/main/cuda/headers/spoof_utils.cuh 
b/src/main/cuda/headers/spoof_utils.cuh
index e28d254..dab9aec 100644
--- a/src/main/cuda/headers/spoof_utils.cuh
+++ b/src/main/cuda/headers/spoof_utils.cuh
@@ -28,21 +28,31 @@ __constant__ double FLOAT_EPS = 1.49012E-08; // 2 ^ -26
 __constant__ double EPSILON = 1E-11; // margin for comparisons ToDo: make 
consistent use of it
 
 __device__ long long toInt64(double a) {
-    return (signbit(a) == 0 ? 1.0 : -1.0) * abs(floor(a + DOUBLE_EPS));
+       return (signbit(a) == 0 ? 1.0 : -1.0) * abs(floor(a + DOUBLE_EPS));
 }
 
 __device__ int toInt32(float a) {
-    return (signbit(a) == 0 ? 1.0 : -1.0) * abs(floor(a + FLOAT_EPS));
+       return (signbit(a) == 0 ? 1.0 : -1.0) * abs(floor(a + FLOAT_EPS));
 }
 
 template<typename T>
 __device__ T getValue(T* data, int rowIndex) {
-    return data[rowIndex];
+       return data[rowIndex];
 }
 
 template<typename T>
-__device__ T getValue(T* data, int n, int rowIndex, int colIndex) {
-    return data[rowIndex * n + colIndex];
+__device__ T getValue(MatrixAccessor<T> data, int rowIndex) {
+       return data[rowIndex];
+}
+
+template<typename T>
+__device__ T getValue(T* data, uint32_t n, uint32_t rowIndex, uint32_t 
colIndex) {
+       return data[rowIndex * n + colIndex];
+}
+
+template<typename T>
+__device__ T getValue(MatrixAccessor<T>& data, uint32_t n, uint32_t rowIndex, 
uint32_t colIndex) {
+       return data[rowIndex * n + colIndex];
 }
 
 template<typename T>
@@ -50,14 +60,14 @@ __device__ T intDiv(T a, T b);
 
 template<>
 __device__ double intDiv(double a, double b) {
-    double ret = a / b;
-    return (isnan(ret) || isinf(ret)) ? ret : toInt64(ret);
+       double ret = a / b;
+       return (isnan(ret) || isinf(ret)) ? ret : toInt64(ret);
 }
 
 template<>
 __device__ float intDiv(float a, float b) {
-    float ret = a / b;
-    return (isnan(ret) || isinf(ret)) ? ret : toInt32(ret);
+       float ret = a / b;
+       return (isnan(ret) || isinf(ret)) ? ret : toInt32(ret);
 }
 
 template<typename T>
@@ -65,16 +75,16 @@ __device__ T modulus(T a, T b);
 
 template<>
 __device__ double modulus(double a, double b) {
-    if (fabs(b) < DOUBLE_EPS)
-        return CUDART_NAN;
-    return a - intDiv(a, b) * b;
+       if (fabs(b) < DOUBLE_EPS)
+               return CUDART_NAN;
+       return a - intDiv(a, b) * b;
 }
 
 template<>
 __device__ float modulus(float a, float b) {
-    if (fabs(b) < FLOAT_EPS)
-        return CUDART_NAN_F;
-    return a - intDiv(a, b) * b;
+       if (fabs(b) < FLOAT_EPS)
+               return CUDART_NAN_F;
+       return a - intDiv(a, b) * b;
 }
 
 template<typename T>
diff --git a/src/main/cuda/kernels/reduction.cu 
b/src/main/cuda/kernels/reduction.cu
index 04fd098..3a11f77 100644
--- a/src/main/cuda/kernels/reduction.cu
+++ b/src/main/cuda/kernels/reduction.cu
@@ -20,9 +20,11 @@
 #include "utils.cuh"
 #include "agg_ops.cuh"
 #include "reduction.cuh"
+#include "Matrix.h"
 
 using uint = unsigned int;
 #include <cuda_runtime.h>
+#ifdef __CUDACC__
 
 /**
  * Do a summation over all elements of an array/matrix
@@ -31,19 +33,21 @@ using uint = unsigned int;
  * @param n         size of the input and temporary/output arrays
  */
 template<typename T>
-__device__ void reduce_sum(T *g_idata, T *g_odata, uint n) {
+__device__ void reduce_sum(MatrixAccessor<T>* src, MatrixAccessor<T>* dst, uint32_t n) { // sum over the n input elements
        SumOp<T> agg_op;        
        IdentityOp<T> spoof_op;
-       FULL_AGG<T, SumOp<T>, IdentityOp<T>>(g_idata, g_odata, n, 1, (T) 0.0, agg_op, spoof_op);
+       FULL_AGG<T, SumOp<T>, IdentityOp<T>>(src, dst, n, T(0), agg_op, spoof_op); // T(0) is the additive identity
 }
 
-extern "C" __global__ void reduce_sum_d(double *g_idata, double *g_odata, uint n) {
-       reduce_sum(g_idata, g_odata, n);
-}
-
-extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata, uint n) {
-       reduce_sum(g_idata, g_odata, n);
+extern "C" __global__ void reduce_sum_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { // C linkage: looked up by name via cuModuleGetFunction
+       MatrixAccessor<double> in_view(in);   // accessor wrappers around the Matrix<double> arguments
+       MatrixAccessor<double> out_view(out);
+       reduce_sum(&in_view, &out_view, n);
 }
+#endif // __CUDACC__ -- NOTE(review): reduce_max_d / reduce_min_d further down are not inside this guard; confirm intended
+//extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata, uint n) {
+//     reduce_sum(g_idata, g_odata, n);
+//}
 
 /**
  * Do a summation over all rows of a matrix
@@ -52,20 +56,20 @@ extern "C" __global__ void reduce_sum_f(float *g_idata, 
float *g_odata, uint n)
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_row_sum(T *g_idata, T *g_odata, uint rows, uint cols) {
-       SumOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       ROW_AGG<T, SumOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 0.0, 
agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_row_sum_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_row_sum(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_row_sum_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_row_sum(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_row_sum(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     SumOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     ROW_AGG<T, SumOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 0.0, 
agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_row_sum_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_row_sum(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_row_sum_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_row_sum(g_idata, g_odata, rows, cols);
+//}
 
 /**
  * Do a summation over all columns of a matrix
@@ -74,20 +78,20 @@ extern "C" __global__ void reduce_row_sum_f(float *g_idata, 
float *g_odata, uint
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_col_sum(T *g_idata, T *g_odata, uint rows, uint cols) {
-       SumOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       COL_AGG<T, SumOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_col_sum_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_col_sum(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_col_sum_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_col_sum(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_col_sum(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     SumOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     COL_AGG<T, SumOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_col_sum_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_col_sum(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_col_sum_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_col_sum(g_idata, g_odata, rows, cols);
+//}
 
 
 /**
@@ -97,19 +101,21 @@ extern "C" __global__ void reduce_col_sum_f(float 
*g_idata, float *g_odata, uint
  * @param n         size of the input and temporary/output arrays
  */
 template<typename T>
-__device__ void reduce_max(T *g_idata, T *g_odata, uint n) {
+__device__ void reduce_max(MatrixAccessor<T>* src, MatrixAccessor<T>* dst, uint32_t n) { // max over the n input elements
        MaxOp<T> agg_op;
        IdentityOp<T> spoof_op;
-       FULL_AGG<T, MaxOp<T>, IdentityOp<T>>(g_idata, g_odata, n, 1, -MAX<T>(), agg_op, spoof_op);
+       FULL_AGG<T, MaxOp<T>, IdentityOp<T>>(src, dst, n, -MAX<T>(), agg_op, spoof_op); // -MAX<T>() in the slot reduce_sum fills with 0
 }
 
-extern "C" __global__ void reduce_max_d(double *g_idata, double *g_odata, uint n) {
-       reduce_max(g_idata, g_odata, n);
+extern "C" __global__ void reduce_max_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { // C linkage: looked up by name via cuModuleGetFunction
+       MatrixAccessor<double> in_view(in);   // accessor wrappers around the Matrix<double> arguments
+       MatrixAccessor<double> out_view(out);
+       reduce_max(&in_view, &out_view, n);
 }

-extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata, uint n) {
-       reduce_max(g_idata, g_odata, n);
-}
+//extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata, uint n) {
+//     reduce_max(g_idata, g_odata, n);
+//}
 
 /**
  * Do a max over all rows of a matrix
@@ -118,20 +124,20 @@ extern "C" __global__ void reduce_max_f(float *g_idata, 
float *g_odata, uint n)
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_row_max(T *g_idata, T *g_odata, uint rows, uint cols) {
-       MaxOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       ROW_AGG<T, MaxOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
-MAX<T>(), agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_row_max_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_row_max(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_row_max_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_row_max(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_row_max(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     MaxOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     ROW_AGG<T, MaxOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
-MAX<T>(), agg_op, spoof_op);
+//}
+
+//extern "C" __global__ void reduce_row_max_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_row_max(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_row_max_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_row_max(g_idata, g_odata, rows, cols);
+//}
 
 /**
  * Do a max over all columns of a matrix
@@ -140,20 +146,20 @@ extern "C" __global__ void reduce_row_max_f(float 
*g_idata, float *g_odata, uint
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_col_max(T *g_idata, T *g_odata, uint rows, uint cols) {
-       MaxOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       COL_AGG<T, MaxOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
-MAX<T>(), agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_col_max_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_col_max(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_col_max_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_col_max(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_col_max(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     MaxOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     COL_AGG<T, MaxOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
-MAX<T>(), agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_col_max_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_col_max(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_col_max_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_col_max(g_idata, g_odata, rows, cols);
+//}
 
 
 /**
@@ -163,19 +169,21 @@ extern "C" __global__ void reduce_col_max_f(float 
*g_idata, float *g_odata, uint
  * @param n         size of the input and temporary/output arrays
  */
 template<typename T>
-__device__ void reduce_min(T *g_idata, T *g_odata, uint n) {
+__device__ void reduce_min(MatrixAccessor<T>* src, MatrixAccessor<T>* dst, uint32_t n) { // min over the n input elements
        MinOp<T> agg_op;
        IdentityOp<T> spoof_op;
-       FULL_AGG<T, MinOp<T>, IdentityOp<T>>(g_idata, g_odata, n, 1, MAX<T>(), agg_op, spoof_op);
+       FULL_AGG<T, MinOp<T>, IdentityOp<T>>(src, dst, n, MAX<T>(), agg_op, spoof_op); // MAX<T>() in the slot reduce_sum fills with 0
 }
 
-extern "C" __global__ void reduce_min_d(double *g_idata, double *g_odata, uint n) {
-       reduce_min(g_idata, g_odata, n);
+extern "C" __global__ void reduce_min_d(Matrix<double>* in, Matrix<double>* out, uint32_t n) { // C linkage: looked up by name via cuModuleGetFunction
+       MatrixAccessor<double> in_view(in);   // accessor wrappers around the Matrix<double> arguments
+       MatrixAccessor<double> out_view(out);
+       reduce_min(&in_view, &out_view, n);
 }

-extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata, uint n) {
-       reduce_min(g_idata, g_odata, n);
-}
+//extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata, uint n) {
+//     reduce_min(g_idata, g_odata, n);
+//}
 
 
 /**
@@ -185,20 +193,20 @@ extern "C" __global__ void reduce_min_f(float *g_idata, 
float *g_odata, uint n)
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_row_min(T *g_idata, T *g_odata, uint rows, uint cols) {
-       MinOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       ROW_AGG<T, MinOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
MAX<T>(), agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_row_min_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_row_min(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_row_min_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_row_min(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_row_min(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     MinOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     ROW_AGG<T, MinOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
MAX<T>(), agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_row_min_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_row_min(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_row_min_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_row_min(g_idata, g_odata, rows, cols);
+//}
 
 /**
  * Do a min over all columns of a matrix
@@ -207,20 +215,20 @@ extern "C" __global__ void reduce_row_min_f(float 
*g_idata, float *g_odata, uint
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_col_min(T *g_idata, T *g_odata, uint rows, uint cols) {
-       MinOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       COL_AGG<T, MinOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
MAX<T>(), agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_col_min_d(double *g_idata, double *g_odata, 
uint rows, uint cols) {
-       reduce_col_min(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_col_min_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
-       reduce_col_min(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_col_min(T *g_idata, T *g_odata, uint rows, uint cols) 
{
+//     MinOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     COL_AGG<T, MinOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
MAX<T>(), agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_col_min_d(double *g_idata, double 
*g_odata, uint rows, uint cols) {
+//     reduce_col_min(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_col_min_f(float *g_idata, float *g_odata, 
uint rows, uint cols) {
+//     reduce_col_min(g_idata, g_odata, rows, cols);
+//}
 
 
 /**
@@ -229,20 +237,20 @@ extern "C" __global__ void reduce_col_min_f(float 
*g_idata, float *g_odata, uint
  * @param g_odata   output/temporary array stored in device memory (of size n)
  * @param n         size of the input and temporary/output arrays
  */
-template<typename T>
-__device__ void reduce_sum_sq(T *g_idata, T *g_odata, uint n) {
-       SumSqOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       FULL_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, n, 1, (T) 0.0, 
agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_sum_sq_d(double *g_idata, double *g_odata, 
uint n) {
-       reduce_sum_sq(g_idata, g_odata, n);
-}
-
-extern "C" __global__ void reduce_sum_sq_f(float *g_idata, float *g_odata, 
uint n) {
-       reduce_sum_sq(g_idata, g_odata, n);
-}
+//template<typename T>
+//__device__ void reduce_sum_sq(T *g_idata, T *g_odata, uint n) {
+//     SumSqOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     FULL_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, n, 1, (T) 0.0, 
agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_sum_sq_d(double *g_idata, double *g_odata, 
uint n) {
+//     reduce_sum_sq(g_idata, g_odata, n);
+//}
+//
+//extern "C" __global__ void reduce_sum_sq_f(float *g_idata, float *g_odata, 
uint n) {
+//     reduce_sum_sq(g_idata, g_odata, n);
+//}
 
 /**
  * Do a summation over all squared elements of an array/matrix
@@ -251,32 +259,32 @@ extern "C" __global__ void reduce_sum_sq_f(float 
*g_idata, float *g_odata, uint
  * @param rows      number of rows in input matrix
  * @param cols      number of columns in input matrix
  */
-template<typename T>
-__device__ void reduce_col_sum_sq(T* g_idata, T* g_odata, uint rows, uint 
cols) {
-       SumSqOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       COL_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_col_sum_sq_d(double* g_idata, double* 
g_odata, uint rows, uint cols) {
-       reduce_col_sum_sq(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_col_sum_sq_f(float* g_idata, float* g_odata, 
uint rows, uint cols) {
-       reduce_col_sum_sq(g_idata, g_odata, rows, cols);
-}
-
-template<typename T>
-__device__ void reduce_row_sum_sq(T* g_idata, T* g_odata, uint rows, uint 
cols) {
-       SumSqOp<T> agg_op;
-       IdentityOp<T> spoof_op;
-       ROW_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
-}
-
-extern "C" __global__ void reduce_row_sum_sq_d(double* g_idata, double* 
g_odata, uint rows, uint cols) {
-       reduce_row_sum_sq(g_idata, g_odata, rows, cols);
-}
-
-extern "C" __global__ void reduce_row_sum_sq_f(float* g_idata, float* g_odata, 
uint rows, uint cols) {
-       reduce_row_sum_sq(g_idata, g_odata, rows, cols);
-}
+//template<typename T>
+//__device__ void reduce_col_sum_sq(T* g_idata, T* g_odata, uint rows, uint 
cols) {
+//     SumSqOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     COL_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_col_sum_sq_d(double* g_idata, double* 
g_odata, uint rows, uint cols) {
+//     reduce_col_sum_sq(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_col_sum_sq_f(float* g_idata, float* 
g_odata, uint rows, uint cols) {
+//     reduce_col_sum_sq(g_idata, g_odata, rows, cols);
+//}
+
+//template<typename T>
+//__device__ void reduce_row_sum_sq(T* g_idata, T* g_odata, uint rows, uint 
cols) {
+//     SumSqOp<T> agg_op;
+//     IdentityOp<T> spoof_op;
+//     ROW_AGG<T, SumSqOp<T>, IdentityOp<T>>(g_idata, g_odata, rows, cols, 
(T)0.0, agg_op, spoof_op);
+//}
+//
+//extern "C" __global__ void reduce_row_sum_sq_d(double* g_idata, double* 
g_odata, uint rows, uint cols) {
+//     reduce_row_sum_sq(g_idata, g_odata, rows, cols);
+//}
+//
+//extern "C" __global__ void reduce_row_sum_sq_f(float* g_idata, float* 
g_odata, uint rows, uint cols) {
+//     reduce_row_sum_sq(g_idata, g_odata, rows, cols);
+//}
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
index 36299c7..d3f0778 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.cpp
@@ -45,26 +45,26 @@ size_t SpoofCUDAContext::initialize_cuda(uint32_t 
device_id, const char* resourc
   // SUM
   CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_d"));
   ctx->reduction_kernels.insert(std::make_pair("reduce_sum_d", func));
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_f"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_f", func));
-
-  // SUM_SQ
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_d"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_d", func));
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_f"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_f", func));
-
-  // MIN
+//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_f"));
+//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_f", func));
+//  NOTE(review): float and SUM_SQ registrations are disabled; the matching kernels are commented out in reduction.cu in this change
+//  // SUM_SQ
+//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_d"));
+//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_d", func));
+//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_sum_sq_f"));
+//  ctx->reduction_kernels.insert(std::make_pair("reduce_sum_sq_f", func));
+//
+  // MIN
   CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_d"));
   ctx->reduction_kernels.insert(std::make_pair("reduce_min_d", func));
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_f"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_min_f", func));
-
-  // MAX
+//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_min_f"));
+//  ctx->reduction_kernels.insert(std::make_pair("reduce_min_f", func));
+//
+  // MAX
   CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_d"));
   ctx->reduction_kernels.insert(std::make_pair("reduce_max_d", func));
-  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_f"));
-  ctx->reduction_kernels.insert(std::make_pair("reduce_max_f", func));
+//  CHECK_CUDA(cuModuleGetFunction(&func, ctx->reductions, "reduce_max_f"));
+//  ctx->reduction_kernels.insert(std::make_pair("reduce_max_f", func));

   return reinterpret_cast<size_t>(ctx);
 }
@@ -95,47 +95,100 @@ bool SpoofCUDAContext::compile_cuda(const std::string &src,
   std::cout << "cuda_path: " << cuda_include_path << std::endl;
 #endif

-  SpoofOperator::AggType type = SpoofOperator::AggType::NONE;
-  SpoofOperator::AggOp op = SpoofOperator::AggOp::NONE;
-
-  auto pos = 0;
-  if((pos = src.find("CellType")) != std::string::npos) {
-      if(src.substr(pos, pos+30).find("FULL_AGG") != std::string::npos)
-          type = SpoofOperator::AggType::FULL_AGG;
-      else if(src.substr(pos, pos+30).find("ROW_AGG") != std::string::npos)
-          type = SpoofOperator::AggType::ROW_AGG;
-      else if(src.substr(pos, pos+30).find("COL_AGG") != std::string::npos)
-          type = SpoofOperator::AggType::COL_AGG;
-      else if(src.substr(pos, pos+30).find("NO_AGG") != std::string::npos)
-          type = SpoofOperator::AggType::NO_AGG;
-      else {
-          std::cerr << "error: unknown aggregation type" << std::endl;
-          return false;
-      }
-
-      if(type != SpoofOperator::AggType::NO_AGG) {
-          if((pos = src.find("AggOp")) != std::string::npos) {
-              if(src.substr(pos, pos+30).find("AggOp.SUM") != std::string::npos)
-                  op = SpoofOperator::AggOp::SUM;
-              else if(src.substr(pos, pos+30).find("AggOp.SUM_SQ") != std::string::npos)
-                  op = SpoofOperator::AggOp::SUM_SQ;
-              else if(src.substr(pos, pos+30).find("AggOp.MIN") != std::string::npos)
-                  op = SpoofOperator::AggOp::MIN;
-              else if(src.substr(pos, pos+30).find("AggOp.MAX") != std::string::npos)
-                  op = SpoofOperator::AggOp::MAX;
-              else {
-                std::cerr << "error: unknown aggregation operator" << std::endl;
-                return false;
-              }
-          }
-      }
-  }
-
-  std::stringstream s1, s2, s3;
-  s1 << "-I" << resource_path << "/cuda/headers";
-  s2 << "-I" << resource_path << "/cuda/spoof";
-
-  jitify::Program program = kernel_cache.program(src, 0, {s1.str(), s2.str(), cuda_include_path});
-  ops.insert(std::make_pair(name, SpoofOperator({std::move(program), type, op})));
-  return true;
+       // Derive the operator's metadata from the "// Key: value" markers in the generated source.
+       // Every keyword test is confined to the remainder of the marker's own line (std::string::substr
+       // takes (pos, count); the previous substr(pos, pos + k) windows grew with pos and could match
+       // tokens far past the marker).
+       auto rest_of_line = [&src](std::size_t at) {
+               return src.substr(at, src.find('\n', at) - at);
+       };
+
+       SpoofOperator::AggType agg_type = SpoofOperator::AggType::NONE;
+       SpoofOperator::AggOp agg_op = SpoofOperator::AggOp::NONE;
+       SpoofOperator::OpType op_type = SpoofOperator::OpType::NONE;
+
+       // std::size_t rather than int: find() returns std::string::npos on a miss
+       std::size_t pos = 0;
+
+       if((pos = src.find("SpoofCellwiseOp")) != std::string::npos)
+               op_type = SpoofOperator::OpType::CW;
+       else if((pos = src.find("SpoofRowwiseOp")) != std::string::npos)
+               op_type = SpoofOperator::OpType::RA;
+       else {
+               std::cerr << "error: unknown spoof operator" << std::endl;
+               return false;
+       }
+
+       bool TB1 = false;
+       if((pos = src.find("TB1")) != std::string::npos)
+               TB1 = rest_of_line(pos).find("true") != std::string::npos;
+
+       // Malformed numeric markers are reported through the normal false-return path instead of
+       // letting std::stoi's exception escape from this function.
+       uint32_t numTempVect = 0;
+       int32_t constDim2 = 0;
+       try {
+               if((pos = src.find("// VectMem: ")) != std::string::npos)
+                       numTempVect = std::stoi(rest_of_line(pos + 12));
+               if((pos = src.find("// ConstDim2: ")) != std::string::npos)
+                       constDim2 = std::stoi(rest_of_line(pos + 14));
+       }
+       catch(const std::exception& e) {
+               std::cerr << "error: malformed VectMem/ConstDim2 marker: " << e.what() << std::endl;
+               return false;
+       }
+
+       bool sparse_safe = false;
+       if((pos = src.find("// SparseSafe:")) != std::string::npos)
+               sparse_safe = rest_of_line(pos).find("true") != std::string::npos;
+
+       if(((pos = src.find("CellType")) != std::string::npos) || ((pos = src.find("RowType")) != std::string::npos)) {
+               const std::string type_line = rest_of_line(pos);
+               auto type_is = [&type_line](const char* token) { return type_line.find(token) != std::string::npos; };
+               // Longer names first: "COL_AGG" is a prefix of "COL_AGG_T" and "NO_AGG" of "NO_AGG_CONST",
+               // so testing the shorter name first made the longer variants unreachable.
+               if(type_is("FULL_AGG"))
+                       agg_type = SpoofOperator::AggType::FULL_AGG;
+               else if(type_is("ROW_AGG"))
+                       agg_type = SpoofOperator::AggType::ROW_AGG;
+               else if(type_is("COL_AGG_T"))
+                       agg_type = SpoofOperator::AggType::COL_AGG_T;
+               else if(type_is("COL_AGG"))
+                       agg_type = SpoofOperator::AggType::COL_AGG;
+               else if(type_is("NO_AGG_CONST"))
+                       agg_type = SpoofOperator::AggType::NO_AGG_CONST;
+               else if(type_is("NO_AGG"))
+                       agg_type = SpoofOperator::AggType::NO_AGG;
+               else {
+                       std::cerr << "error: unknown aggregation type" << std::endl;
+                       return false;
+               }
+
+               if((agg_type != SpoofOperator::AggType::NO_AGG) && (op_type == SpoofOperator::OpType::CW)) {
+                       if((pos = src.find("AggOp")) != std::string::npos) {
+                               const std::string op_line = rest_of_line(pos);
+                               // "AggOp.SUM" is a prefix of "AggOp.SUM_SQ", so SUM_SQ is tested first.
+                               if(op_line.find("AggOp.SUM_SQ") != std::string::npos)
+                                       agg_op = SpoofOperator::AggOp::SUM_SQ;
+                               else if(op_line.find("AggOp.SUM") != std::string::npos)
+                                       agg_op = SpoofOperator::AggOp::SUM;
+                               else if(op_line.find("AggOp.MIN") != std::string::npos)
+                                       agg_op = SpoofOperator::AggOp::MIN;
+                               else if(op_line.find("AggOp.MAX") != std::string::npos)
+                                       agg_op = SpoofOperator::AggOp::MAX;
+                               else {
+                                       std::cerr << "error: unknown aggregation operator" << std::endl;
+                                       return false;
+                               }
+                       }
+               }
+       }
+
+       std::stringstream s1, s2;
+       s1 << "-I" << resource_path << "/cuda/headers";
+       s2 << "-I" << resource_path << "/cuda/spoof";
+
+       jitify::Program program = kernel_cache.program(src, 0, {s1.str(), s2.str(), cuda_include_path});
+       ops.insert(std::make_pair(name, SpoofOperator({std::move(program), agg_type, agg_op, op_type, name, TB1, constDim2, numTempVect, sparse_safe})));
+       return true;
 }
diff --git a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h 
b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
index 36d29ec..fec38bc 100644
--- a/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
+++ b/src/main/cuda/spoof-launcher/SpoofCUDAContext.h
@@ -21,14 +21,22 @@
 #ifndef SPOOFCUDACONTEXT_H
 #define SPOOFCUDACONTEXT_H
 
+#define NOMINMAX
+
 #include <cmath>
 #include <cstdint>
 #include <map>
 #include <string>
+#include <algorithm>
 
-#ifdef __DEBUG
-    #define JITIFY_PRINT_ALL 1
+#include "Matrix.h"
+
+#ifdef _DEBUG
+#define __DEBUG
 #endif
+// #ifdef __DEBUG
+    // #define JITIFY_PRINT_ALL 1
+// #endif
 
 #include <jitify.hpp>
 
@@ -37,13 +45,19 @@
 using jitify::reflection::type_of;
 
 struct SpoofOperator {
-  enum class AggType : int { NO_AGG, ROW_AGG, COL_AGG, FULL_AGG, NONE };
-  enum class AggOp : int {SUM, SUM_SQ, MIN, MAX, NONE };
-
-  jitify::Program program;
-  AggType agg_type;
-  AggOp agg_op;
-
+       enum class AggType : int { NO_AGG, NO_AGG_CONST, ROW_AGG, COL_AGG, FULL_AGG, COL_AGG_T, NONE };
+       enum class AggOp : int { SUM, SUM_SQ, MIN, MAX, NONE };
+       enum class OpType : int { CW, RA, MA, OP, NONE }; // compile_cuda maps SpoofCellwiseOp -> CW, SpoofRowwiseOp -> RA; MA/OP presumably other templates (confirm)
+       // Compiled jitify program plus the metadata compile_cuda reads from the generated source's markers.
+       jitify::Program program;
+       AggType agg_type;
+       AggOp agg_op;
+       OpType op_type;
+       std::string name;               // not const: a const member would delete copy/move assignment of SpoofOperator
+       bool TB1 = false;
+       int32_t const_dim2 = 0;         // "// ConstDim2:" marker value parsed by compile_cuda; defaults to 0
+       uint32_t num_temp_vectors = 0;  // "// VectMem:" marker value parsed by compile_cuda; defaults to 0
+       bool sparse_safe = true;
 };
 
 class SpoofCUDAContext {
@@ -71,199 +85,338 @@ public:
 
   bool compile_cuda(const std::string &src, const std::string &name);
 
-  template <typename T>
-  T execute_kernel(const std::string &name, T **in_ptrs, int num_inputs,
-                   T **side_ptrs, int num_sides, T *out_ptr, T *scalars_ptr,
-                   int num_scalars, int m, int n, int grix) {
-
-    T result = 0.0;
-    size_t dev_buf_size;
-    T **d_sides = nullptr;
-    T *d_scalars = nullptr;
-    T *d_temp_agg_buf;
-    uint32_t N = m * n;
-
-    auto o = ops.find(name);
-    if (o != ops.end()) {
-      SpoofOperator *op = &(o->second);
-
-      if (num_sides > 0) {
-        dev_buf_size = sizeof(T *) * num_sides;
-        CHECK_CUDART(cudaMalloc((void **)&d_sides, dev_buf_size));
-        CHECK_CUDART(cudaMemcpy(d_sides, side_ptrs, dev_buf_size, cudaMemcpyHostToDevice));
-      }
-
-      if (num_scalars > 0) {
-        dev_buf_size = sizeof(T) * num_scalars;
-        CHECK_CUDART(cudaMalloc((void **)&d_scalars, dev_buf_size));
-        CHECK_CUDART(cudaMemcpy(d_scalars, scalars_ptr, dev_buf_size, cudaMemcpyHostToDevice));
-      }
-
-      switch (op->agg_type) {
-          case SpoofOperator::AggType::FULL_AGG: {
-            // num ctas
-            int NB = std::ceil((N + NT * 2 - 1) / (NT * 2));
-            dim3 grid(NB, 1, 1);
-            dim3 block(NT, 1, 1);
-            unsigned int shared_mem_size = NT * sizeof(T);
-
-            dev_buf_size = sizeof(T) * NB;
-            CHECK_CUDART(cudaMalloc((void **)&d_temp_agg_buf, dev_buf_size));
+       /**
+        * Runs the compiled SPOOF operator `name` on device-resident matrices.
+        *
+        * input:   main inputs; only input[0] is used so far (ToDo: SpoofOuterProduct).
+        * sides:   side inputs. For TB1 operators sides[0] is replaced by a device-side
+        *          transpose for the duration of the call and restored before returning.
+        * output:  device output matrix, or nullptr, in which case a temporary
+        *          aggregation buffer is allocated (one slot per first-pass block for
+        *          CW operators, a single slot otherwise).
+        *
+        * Returns the value read back for FULL_AGG operators, 0 otherwise (also when
+        * the operator is unknown). Every temporary device allocation made here is
+        * released before returning; the caller's output buffers are never freed.
+        */
+       template <typename T>
+       T execute_kernel(const std::string &name, std::vector<Matrix<T>>& input, std::vector<Matrix<T>>& sides, Matrix<T>* output,
+                       T *scalars_ptr, uint32_t num_scalars, uint32_t grix) {
+               T result = 0.0;
+               size_t dev_buf_size;
+               Matrix<T>* d_in = nullptr;
+               Matrix<T>* d_out = nullptr;
+               Matrix<T>* d_sides = nullptr;
+               T* d_agg_buf = nullptr;      // aggregation buffer owned by this call (output == nullptr)
+               T* b1_transposed = nullptr;
+               T *d_scalars = nullptr;
+
+               // original sides[0] fields, restored after a TB1 transpose
+               T* side0_data = nullptr;
+               uint32_t side0_rows = 0;
+               uint32_t side0_cols = 0;
+
+               auto o = ops.find(name);
+               if (o == ops.end()) {
+                       std::cerr << "kernel " << name << " not found." << std::endl;
+                       return result;
+               }
+               SpoofOperator *op = &(o->second);
+
+               // ToDo: multiple inputs for SpoofOuterProduct template
+               CHECK_CUDART(cudaMalloc((void **)&d_in, sizeof(Matrix<T>)));
+               CHECK_CUDART(cudaMemcpy(d_in, reinterpret_cast<void*>(&input[0]), sizeof(Matrix<T>), cudaMemcpyHostToDevice));
+
+               if(output != nullptr) {
+                       if (op->sparse_safe && input.front().row_ptr != nullptr) {
 #ifdef __DEBUG
-            // ToDo: connect output to SystemDS logging facilities
-            std::cout << "launching spoof cellwise kernel " << name << " with "
-                      << NT * NB << " threads in " << NB << " blocks and "
-                      << shared_mem_size
-                      << " bytes of shared memory for full aggregation of "
-                      << N << " elements"
-                      << std::endl;
+                               std::cout << "copying sparse safe row ptrs" << std::endl;
 #endif
-            CHECK_CUDA(op->program.kernel(name)
-                .instantiate(type_of(result))
-                .configure(grid, block, shared_mem_size)
-                .launch(in_ptrs[0], d_sides, d_temp_agg_buf, d_scalars, m, n, grix));
-
-            if(NB > 1) {
-                std::string reduction_kernel_name = determine_agg_kernel<T>(op);
-
-                CUfunction reduce_kernel = reduction_kernels.find(reduction_kernel_name)->second;
-                N = NB;
-                int iter = 1;
-                while (NB > 1) {
-                    void* args[3] = { &d_temp_agg_buf, &d_temp_agg_buf, &N};
-
-                    NB = std::ceil((N + NT * 2 - 1) / (NT * 2));
+                               CHECK_CUDART(cudaMemcpy(output->row_ptr, input.front().row_ptr, (input.front().rows+1)*sizeof(uint32_t), cudaMemcpyDeviceToDevice));
+                       }
+
+                       CHECK_CUDART(cudaMalloc((void **) &d_out, sizeof(Matrix<T>)));
+                       CHECK_CUDART(cudaMemcpy(d_out, reinterpret_cast<void *>(output), sizeof(Matrix<T>), cudaMemcpyHostToDevice));
+               }
+               else {
+                       // No output matrix: aggregate into a temporary buffer with one slot per
+                       // block of the first cellwise reduction pass (a single slot otherwise).
+                       uint32_t num_blocks = 1;
+                       if (op->op_type == SpoofOperator::OpType::CW)
+                               num_blocks = std::ceil(((input.front().rows * input.front().cols) + NT * 2 - 1) / (NT * 2));
+
+                       CHECK_CUDART(cudaMalloc((void **) &d_out, sizeof(Matrix<T>)));
+                       CHECK_CUDART(cudaMalloc((void **) &d_agg_buf, sizeof(T) * num_blocks));
+                       Matrix<T> agg_out{d_agg_buf, 0, 0, num_blocks, 1, num_blocks};
+                       CHECK_CUDART(cudaMemcpy(d_out, reinterpret_cast<void *>(&agg_out), sizeof(Matrix<T>), cudaMemcpyHostToDevice));
+               }
+
+               if (!sides.empty()) {
+                       if(op->TB1) {
 #ifdef __DEBUG
-                    std::cout << "agg iter " << iter++ << " launching spoof cellwise kernel " << name << " with "
+                               std::cout << "transposing TB1 for " << op->name << std::endl;
+#endif
+                               T* b1 = sides[0].data;
+                               uint32_t m = sides[0].rows;
+                               uint32_t n = sides[0].cols;
+                               side0_data = b1;
+                               side0_rows = m;
+                               side0_cols = n;
+
+                               // NOTE(review): cublasDgeam takes double operands, so this path only
+                               // compiles for T = double; a float instantiation needs cublasSgeam.
+                               CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&b1_transposed), sizeof(T) * m * n));
+                               double alpha = 1.0;
+                               double beta  = 0.0;
+                               cublasHandle_t handle;
+
+                               CHECK_CUBLAS(cublasCreate(&handle));
+                               CHECK_CUBLAS(cublasDgeam(handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, &alpha, b1, n, &beta, b1, n, b1_transposed, m));
+                               sides[0].data = b1_transposed;
+                               sides[0].rows = n;
+                               sides[0].cols = m;
+                               CHECK_CUBLAS(cublasDestroy(handle));
+                       }
+                       dev_buf_size = sizeof(Matrix<T>) * sides.size();
+                       CHECK_CUDART(cudaMalloc(reinterpret_cast<void **>(&d_sides), dev_buf_size));
+                       CHECK_CUDART(cudaMemcpy(d_sides, &sides[0], dev_buf_size, cudaMemcpyHostToDevice));
+               }
+
+               if (num_scalars > 0) {
+                       dev_buf_size = sizeof(T) * num_scalars;
+                       CHECK_CUDART(cudaMalloc((void **)&d_scalars, dev_buf_size));
+                       CHECK_CUDART(cudaMemcpy(d_scalars, scalars_ptr, dev_buf_size, cudaMemcpyHostToDevice));
+               }
+
+               bool launched = true;
+               switch(op->op_type) {
+               case SpoofOperator::OpType::CW:
+                       launch_cw_kernel(op, d_in, d_out, d_sides, sides.size(), d_scalars, input.front().rows,
+                                       input.front().cols, grix, input[0]);
+                       break;
+               case SpoofOperator::OpType::RA:
+                       launch_ra_kernel(op, d_in, d_out, d_sides, sides.size(), d_scalars, input.front().rows,
+                                       input.front().cols, grix, input[0].row_ptr != nullptr);
+                       break;
+               default:
+                       std::cerr << "error: unknown spoof operator" << std::endl;
+                       launched = false;
+                       break;
+               }
+
+               if (launched && op->agg_type == SpoofOperator::AggType::FULL_AGG) {
+                       // the scalar result is the first element of the output data buffer
+                       Matrix<T> res_mat;
+                       CHECK_CUDART(cudaMemcpy(&res_mat, d_out, sizeof(Matrix<T>), cudaMemcpyDeviceToHost));
+                       CHECK_CUDART(cudaMemcpy(&result, res_mat.data, sizeof(T), cudaMemcpyDeviceToHost));
+               }
+
+               // Release every temporary allocated above (cudaFree(nullptr) is a no-op).
+               // Only d_agg_buf is freed as data buffer: with output != nullptr the data
+               // buffer belongs to the caller.
+               CHECK_CUDART(cudaFree(d_scalars));
+               CHECK_CUDART(cudaFree(d_sides));
+               CHECK_CUDART(cudaFree(b1_transposed));
+               CHECK_CUDART(cudaFree(d_agg_buf));
+               CHECK_CUDART(cudaFree(d_out));
+               CHECK_CUDART(cudaFree(d_in));
+
+               if (b1_transposed != nullptr) {
+                       // do not leave the caller's side input pointing at freed memory
+                       sides[0].data = side0_data;
+                       sides[0].rows = side0_rows;
+                       sides[0].cols = side0_cols;
+               }
+               return result;
+       }
+
+       // Maps the aggregation of `op` to the name of the reduction kernel that
+       // finishes it: "reduce" + {"_" | "_row_" | "_col_"} (FULL/ROW/COL_AGG)
+       // + {"min" | "max" | "sum_sq" | "sum"} + ("_d" if T is double, else "_f").
+       // Logs to std::cerr and returns "" for any other aggregation type or operator.
+       template<typename T>
+       std::string determine_agg_kernel(SpoofOperator* op) {
+               const char* reduction_type = nullptr;
+               if (op->agg_type == SpoofOperator::AggType::FULL_AGG)
+                       reduction_type = "_";
+               else if (op->agg_type == SpoofOperator::AggType::ROW_AGG)
+                       reduction_type = "_row_";
+               else if (op->agg_type == SpoofOperator::AggType::COL_AGG)
+                       reduction_type = "_col_";
+               else {
+                       std::cerr << "unknown reduction type" << std::endl;
+                       return "";
+               }
+
+               const char* agg_name = nullptr;
+               switch (op->agg_op) {
+               case SpoofOperator::AggOp::MIN:    agg_name = "min";    break;
+               case SpoofOperator::AggOp::MAX:    agg_name = "max";    break;
+               case SpoofOperator::AggOp::SUM_SQ: agg_name = "sum_sq"; break;
+               case SpoofOperator::AggOp::SUM:    agg_name = "sum";    break;
+               default:
+                       std::cerr << "unknown reduction op" << std::endl;
+                       return "";
+               }
+
+               const std::string suffix = (typeid(T) == typeid(double)) ? "_d" : "_f";
+               return std::string("reduce") + reduction_type + agg_name + suffix;
+       }
+
+       /**
+        * Launches the cellwise SPOOF kernel of `op`: the "_SPARSE" variant when h_in
+        * has a row pointer array (processing its nnz values), otherwise the "_DENSE"
+        * variant (processing in_rows * in_cols cells).
+        *  FULL_AGG: the first pass writes one partial result per block via d_out, then
+        *            the reduction kernel named by determine_agg_kernel() is applied to
+        *            d_out until a single block remains.
+        *  COL_AGG, ROW_AGG (one block per row) and NO_AGG/default (one block per row
+        *  when sparse): a single launch.
+        * Progress messages are printed only in __DEBUG builds.
+        */
+       template<typename T>
+       void launch_cw_kernel(SpoofOperator* op, Matrix<T>* d_in, Matrix<T>* d_out, Matrix<T>* d_sides, uint32_t num_sides,
+                       T* d_scalars, uint32_t in_rows, uint32_t in_cols, uint32_t grix, const Matrix<T>& h_in) {
+               T value_type;
+               bool sparse = h_in.row_ptr != nullptr;
+               uint32_t N = in_rows * in_cols;
+               std::string op_name(op->name + "_DENSE");
+               if(sparse) {
+                       op_name = std::string(op->name + "_SPARSE");
+                       N = h_in.nnz;
+               }
+
+               switch (op->agg_type) {
+               case SpoofOperator::AggType::FULL_AGG: {
+                       // num ctas
+                       uint32_t NB = std::ceil((N + NT * 2 - 1) / (NT * 2));
+                       dim3 grid(NB, 1, 1);
+                       dim3 block(NT, 1, 1);
+                       uint32_t shared_mem_size = NT * sizeof(T);
+
+#ifdef __DEBUG
+                       // ToDo: connect output to SystemDS logging facilities
+                       std::cout << "launching spoof cellwise kernel " << op_name << " with "
+                                       << NT * NB << " threads in " << NB << " blocks and "
+                                       << shared_mem_size
+                                       << " bytes of shared memory for full aggregation of "
+                                       << N << " elements"
+                                       << std::endl;
+#endif
+
+                       CHECK_CUDA(op->program.kernel(op_name)
+                                       .instantiate(type_of(value_type), std::max(1u, num_sides))
+                                       .configure(grid, block, shared_mem_size)
+                                       .launch(d_in, d_sides, d_out, d_scalars, N, grix));
+
+                       if(NB > 1) {
+                               std::string reduction_kernel_name = determine_agg_kernel<T>(op);
+                               auto reduce_it = reduction_kernels.find(reduction_kernel_name);
+                               if(reduce_it == reduction_kernels.end()) {
+                                       // dereferencing end() would be undefined behavior
+                                       std::cerr << "reduction kernel " << reduction_kernel_name << " not found." << std::endl;
+                                       break;
+                               }
+                               CUfunction reduce_kernel = reduce_it->second;
+                               N = NB;
+#ifdef __DEBUG
+                               uint32_t iter = 1;
+#endif
+                               while (NB > 1) {
+                                       void* args[3] = { &d_out, &d_out, &N};
+
+                                       NB = std::ceil((N + NT * 2 - 1) / (NT * 2));
+#ifdef __DEBUG
+                                       std::cout << "agg iter " << iter++ << " launching spoof cellwise kernel " << op_name << " with "
                     << NT * NB << " threads in " << NB << " blocks and "
                     << shared_mem_size
                     << " bytes of shared memory for full aggregation of "
                     << N << " elements"
                     << std::endl;
-#endif
-                    CHECK_CUDA(cuLaunchKernel(reduce_kernel, 
-                        NB, 1, 1, 
-                        NT, 1, 1,
-                        shared_mem_size, 0, args, 0));
-                    N = NB;
-                }
-            }
-                            
-            CHECK_CUDART(cudaMemcpy(&result, d_temp_agg_buf, sizeof(T), cudaMemcpyDeviceToHost));
-            CHECK_CUDART(cudaFree(d_temp_agg_buf));
-            break;
-          }
-          case SpoofOperator::AggType::COL_AGG: {
-              // num ctas
-              int NB = std::ceil((N + NT - 1) / NT);
-              dim3 grid(NB, 1, 1);
-              dim3 block(NT, 1, 1);
-              unsigned int shared_mem_size = 0;
+#endif
+                                       CHECK_CUDA(cuLaunchKernel(reduce_kernel,
+                                                       NB, 1, 1,
+                                                       NT, 1, 1,
+                                                       shared_mem_size, 0, args, 0));
+                                       N = NB;
+                               }
+                       }
+                       break;
+               }
+               case SpoofOperator::AggType::COL_AGG: {
+                       // num ctas
+                       uint32_t NB = std::ceil((N + NT - 1) / NT);
+                       dim3 grid(NB, 1, 1);
+                       dim3 block(NT, 1, 1);
+                       uint32_t shared_mem_size = 0;
 #ifdef __DEBUG
-              std::cout << " launching spoof cellwise kernel " << name << " with "
-                  << NT * NB << " threads in " << NB << " blocks for column aggregation of "
-                  << N << " elements" << std::endl;
+                       std::cout << " launching spoof cellwise kernel " << op_name << " with "
+                                       << NT * NB << " threads in " << NB << " blocks for column aggregation of "
+                                       << N << " elements" << std::endl;
 #endif
-              CHECK_CUDA(op->program.kernel(name)
-                  .instantiate(type_of(result))
-                  .configure(grid, block)
-                  .launch(in_ptrs[0], d_sides, out_ptr, d_scalars, m, n, grix));
-
-              break;
-          }
-          case SpoofOperator::AggType::ROW_AGG: {
-              // num ctas
-              int NB = m;
-              dim3 grid(NB, 1, 1);
-              dim3 block(NT, 1, 1);
-              unsigned int shared_mem_size = NT * sizeof(T);
+                       CHECK_CUDA(op->program.kernel(op_name)
+                                       .instantiate(type_of(value_type), std::max(1u, num_sides))
+                                       .configure(grid, block, shared_mem_size)
+                                       .launch(d_in, d_sides, d_out, d_scalars, N, grix));
+
+                       break;
+               }
+               case SpoofOperator::AggType::ROW_AGG: {
+                       // num ctas
+                       uint32_t NB = in_rows;
+                       dim3 grid(NB, 1, 1);
+                       dim3 block(NT, 1, 1);
+                       uint32_t shared_mem_size = NT * sizeof(T);
+#ifdef __DEBUG
+                       std::cout << " launching spoof cellwise kernel " << op_name << " with "
+                                       << NT * NB << " threads in " << NB << " blocks and "
+                                       << shared_mem_size << " bytes of shared memory for row aggregation of "
+                                       << N << " elements" << std::endl;
+#endif
+                       CHECK_CUDA(op->program.kernel(op_name)
+                                       .instantiate(type_of(value_type), std::max(1u, num_sides))
+                                       .configure(grid, block, shared_mem_size)
+                                       .launch(d_in, d_sides, d_out, d_scalars, N, grix));
+
+                       break;
+               }
+               case SpoofOperator::AggType::NO_AGG:
+               default: {
+                       // num ctas
+                       // ToDo: VT not a template parameter anymore
+                       uint32_t NB = std::ceil((N + NT * VT - 1) / (NT * VT));
+                       if(sparse)
+                               NB = in_rows;
+                       dim3 grid(NB, 1, 1);
+                       dim3 block(NT, 1, 1);
+                       uint32_t shared_mem_size = 0;
 
 #ifdef __DEBUG
-              std::cout << " launching spoof cellwise kernel " << name << " with "
-                  << NT * NB << " threads in " << NB << " blocks and "
-                  << shared_mem_size << " bytes of shared memory for row aggregation of "
-                  << N << " elements" << std::endl;
-#endif
-              CHECK_CUDA(op->program.kernel(name)
-                  .instantiate(type_of(result))
-                  .configure(grid, block, shared_mem_size)
-                  .launch(in_ptrs[0], d_sides, out_ptr, d_scalars, m, n, grix));
-
-              break;
-          }
-          case SpoofOperator::AggType::NO_AGG: 
-          default: {
-            // num ctas
-              // ToDo: VT not a template parameter anymore
-            int NB = std::ceil((N + NT * VT - 1) / (NT * VT));
-            dim3 grid(NB, 1, 1);
-            dim3 block(NT, 1, 1);
-#ifdef __DEBUG
-            std::cout << "launching spoof cellwise kernel " << name << " with " << NT * NB
-                      << " threads in " << NB << " blocks without aggregation for " 
-                      << N << " elements"
-                      << std::endl;
+                       if(sparse) {
+                               std::cout << "launching sparse spoof cellwise kernel " << op_name << " with " << NT * NB
+                                               << " threads in " << NB << " blocks without aggregation for " << N << " elements"
+                                               << std::endl;
+                       }
+                       else {
+                               std::cout << "launching spoof cellwise kernel " << op_name << " with " << NT * NB
+                                               << " threads in " << NB << " blocks without aggregation for " << N << " elements"
+                                               << std::endl;
+                       }
 #endif
-            CHECK_CUDA(op->program.kernel(name)
-                .instantiate(type_of(result))
-                .configure(grid, block)
-                .launch(in_ptrs[0], d_sides, out_ptr, d_scalars, m, n, grix));
-          }
-      }
-      
-      if (num_scalars > 0)
-        CHECK_CUDART(cudaFree(d_scalars));
-
-      if (num_sides > 0)
-        CHECK_CUDART(cudaFree(d_sides));
-    } 
-    else {
-      std::cerr << "kernel " << name << " not found." << std::endl;
-      return result;
-    }
-    return result;
-  }
-
-  template<typename T>
-  std::string determine_agg_kernel(SpoofOperator* op) {
-      std::string reduction_kernel_name;
-      std::string reduction_type;
-      std::string suffix = (typeid(T) == typeid(double) ? "_d" : "_f");
-      switch (op->agg_type) {
-      case SpoofOperator::AggType::FULL_AGG:
-          reduction_type = "_";
-          break;
-      case SpoofOperator::AggType::ROW_AGG:
-          reduction_type = "_row_";
-          break;
-      case SpoofOperator::AggType::COL_AGG:
-          reduction_type = "_col_";
-          break;
-      default:
-          std::cerr << "unknown reduction type" << std::endl;
-          return "";
-      }
-    
-      switch (op->agg_op) {
-      case SpoofOperator::AggOp::MIN:
-          reduction_kernel_name = "reduce" + reduction_type + "min" + suffix;
-          break;
-      case SpoofOperator::AggOp::MAX:
-          reduction_kernel_name = "reduce" + reduction_type + "max" + suffix;
-          break;
-      case SpoofOperator::AggOp::SUM_SQ:
-          reduction_kernel_name = "reduce" + reduction_type + "sum_sq" + suffix;
-          break;
-      case SpoofOperator::AggOp::SUM:
-          reduction_kernel_name = "reduce" + reduction_type + "sum" + suffix;
-          break;
-      default:
-          std::cerr << "unknown reduction op" << std::endl;
-          return "";
-      }
-
-      return reduction_kernel_name;
-  }
+                       CHECK_CUDA(op->program.kernel(op_name)
+                                       .instantiate(type_of(value_type), std::max(1u, num_sides))
+                                       .configure(grid, block, shared_mem_size)
+                                       .launch(d_in, d_sides, d_out, d_scalars, N, grix));
+               }
+               }
+       }
+
+       /**
+        * Launches the row-wise SPOOF kernel of `op` (the "_SPARSE" variant if `sparse`,
+        * else "_DENSE") with one block of NT threads per input row and
+        * NT * sizeof(T) bytes of shared memory.
+        * If the operator uses temporary vectors, a zero-initialized device buffer of
+        * num_temp_vectors * tmp_len * in_rows values, with tmp_len = max(in_cols, const_dim2),
+        * is allocated for the launch and freed afterwards.
+        */
+       template<typename T>
+       void launch_ra_kernel(SpoofOperator* op, Matrix<T>* d_in, Matrix<T>* d_out, Matrix<T>* d_sides, uint32_t num_sides,
+                       T* d_scalars, uint32_t in_rows, uint32_t in_cols, uint32_t grix, bool sparse) {
+               T value_type;
+               dim3 grid(in_rows, 1, 1);
+               dim3 block(NT, 1, 1);
+               unsigned int shared_mem_size = NT * sizeof(T);
+
+               uint32_t tmp_len = 0;
+               // size_t, widened before the first multiplication: with 32-bit arithmetic the
+               // product can wrap and cudaMalloc/cudaMemset would get an undersized buffer
+               size_t temp_buf_size = 0;
+               T* d_temp = nullptr;
+               if(op->num_temp_vectors > 0) {
+                       tmp_len = std::max(in_cols, op->const_dim2 < 0 ? 0u : static_cast<uint32_t>(op->const_dim2));
+                       temp_buf_size = static_cast<size_t>(op->num_temp_vectors) * tmp_len * in_rows * sizeof(T);
+                       CHECK_CUDART(cudaMalloc(reinterpret_cast<void**>(&d_temp), temp_buf_size));
+                       CHECK_CUDART(cudaMemset(d_temp, 0, temp_buf_size));
+               }
+
+               std::string name(op->name + "_DENSE");
+               if(sparse)
+                       name = std::string(op->name + "_SPARSE");
+
+#ifdef __DEBUG
+               // ToDo: connect output to SystemDS logging facilities
+               std::cout << "launching spoof rowwise kernel " << name << " with " << NT * in_rows << " threads in " << in_rows
+                               << " blocks and " << shared_mem_size << " bytes of shared memory for " << in_cols << " cols processed by "
+                               << NT << " threads per row, adding " << temp_buf_size / 1024 << " kb of temp buffer in global memory." <<  std::endl;
+#endif
+               CHECK_CUDA(op->program.kernel(name)
+                               .instantiate(type_of(value_type), std::max(1u, num_sides), op->num_temp_vectors, tmp_len)
+                               .configure(grid, block, shared_mem_size)
+                               .launch(d_in, d_sides, d_out, d_scalars, d_temp, grix));
+
+               if(op->num_temp_vectors > 0)
+                       CHECK_CUDART(cudaFree(d_temp));
+       }
 };
 
 #endif // SPOOFCUDACONTEXT_H
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.cpp 
b/src/main/cuda/spoof-launcher/jni_bridge.cpp
index 6645003..86d0a1f 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.cpp
+++ b/src/main/cuda/spoof-launcher/jni_bridge.cpp
@@ -19,6 +19,7 @@
 
 #include "jni_bridge.h"
 #include "SpoofCUDAContext.h"
+#include "Matrix.h"
 
 // JNI Methods to get/release arrays
 #define GET_ARRAY(env, input)                                                  
\
@@ -58,52 +59,144 @@ 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_compile_1cuda_1kernel(
 
 JNIEXPORT jdouble JNICALL
 Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d(
-    JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs,
-    jlongArray side_ptrs, jlong out_ptr, jdoubleArray scalars_, jlong m, jlong 
n, jlong grix) {
-
-  SpoofCUDAContext *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx);
-  const char *cstr_name = env->GetStringUTFChars(name, NULL);
-
-  double **inputs = reinterpret_cast<double **>(GET_ARRAY(env, in_ptrs));
-  double **sides = reinterpret_cast<double **>(GET_ARRAY(env, side_ptrs));
-  double *scalars = reinterpret_cast<double *>(GET_ARRAY(env, scalars_));
-
-  double result = ctx_->execute_kernel(
-      cstr_name, inputs, env->GetArrayLength(in_ptrs), sides, 
env->GetArrayLength(side_ptrs),
-      reinterpret_cast<double*>(out_ptr), scalars, 
env->GetArrayLength(scalars_), m, n, grix);
-
-  RELEASE_ARRAY(env, in_ptrs, inputs);
-  RELEASE_ARRAY(env, side_ptrs, sides);
-  RELEASE_ARRAY(env, scalars_, scalars);
-
-  // FIXME: that release causes an error
-  //std::cout << "releasing " << name_ << std::endl;
-  env->ReleaseStringUTFChars(name, cstr_name);
-  return result;
+    JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs, 
jint input_offset, jlongArray side_ptrs,
+               jlongArray out_ptrs, jdoubleArray scalars_, jlong grix, jobject 
inputs_, jobject out_obj) {
+       
+       SpoofCUDAContext *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx);
+       const char *cstr_name = env->GetStringUTFChars(name, NULL);
+       
+       size_t* inputs = reinterpret_cast<size_t*>(GET_ARRAY(env, in_ptrs));
+       size_t* sides = reinterpret_cast<size_t*>(GET_ARRAY(env, side_ptrs));
+       size_t *output = reinterpret_cast<size_t*>(GET_ARRAY(env, out_ptrs));
+       double *scalars = reinterpret_cast<double*>(GET_ARRAY(env, scalars_));
+       
+       //ToDo: call once while init
+       jclass CacheableData = 
env->FindClass("org/apache/sysds/runtime/controlprogram/caching/CacheableData");
+       if(!CacheableData) {
+               std::cerr << " JNIEnv -> FindClass(CacheableData) failed" << 
std::endl;
+               return -1.0;
+       }
+       jclass ArrayList = env->FindClass("java/util/ArrayList");
+       if(!ArrayList) {
+               std::cerr << " JNIEnv -> FindClass(ArrayList) failed" << 
std::endl;
+               return -1.0;
+       }
+       jmethodID mat_obj_num_rows = env->GetMethodID(CacheableData, 
"getNumRows", "()J");
+       if(!mat_obj_num_rows) {
+               std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl;
+               return -1.0;
+       }
+       jmethodID mat_obj_num_cols = env->GetMethodID(CacheableData, 
"getNumColumns", "()J");
+       if(!mat_obj_num_cols) {
+               std::cerr << " JNIEnv -> GetMethodID() failed" << std::endl;
+               return -1.0;
+       }
+       jmethodID ArrayList_size = env->GetMethodID(ArrayList, "size", "()I");
+       jmethodID ArrayList_get = env->GetMethodID(ArrayList, "get", 
"(I)Ljava/lang/Object;");
+
+       std::vector<Matrix<double>> in;
+       jint num_inputs = env->CallIntMethod(inputs_, ArrayList_size);
+#ifdef __DEBUG
+       std::cout << "num inputs: " << num_inputs << " offsets: " << 
input_offset << std::endl;
+#endif
+
+       for(auto ptr_idx = 0, input_idx = 0; input_idx < input_offset; 
ptr_idx+=4, input_idx++) {
+               jobject input_obj = env->CallObjectMethod(inputs_, 
ArrayList_get, input_idx);
+               uint32_t m = 
static_cast<uint32_t>(env->CallIntMethod(input_obj, mat_obj_num_rows));
+               uint32_t n = 
static_cast<uint32_t>(env->CallIntMethod(input_obj, mat_obj_num_cols));
+
+               
in.push_back(Matrix<double>{reinterpret_cast<double*>(inputs[ptr_idx+3]), 
reinterpret_cast<uint32_t*>(inputs[ptr_idx+1]),
+                                                                               
   reinterpret_cast<uint32_t*>(inputs[ptr_idx+2]), m, n,
+                                                                               
   static_cast<uint32_t>(inputs[ptr_idx])});
+#ifdef __DEBUG
+               std::cout << "input #" << input_idx << " m=" << m << " n=" << n 
<< std::endl;
+#endif
+       }
+
+       std::vector<Matrix<double>> side_inputs;
+       for(uint32_t ptr_idx = 0, input_idx = input_offset; input_idx < 
num_inputs; ptr_idx+=4, input_idx++) {
+               jobject side_input_obj = env->CallObjectMethod(inputs_, 
ArrayList_get, input_idx);
+               uint32_t m = 
static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_rows));
+               uint32_t n = 
static_cast<uint32_t>(env->CallIntMethod(side_input_obj, mat_obj_num_cols));
+
+//             std::cout << "sides["<<i << "]=" <<  sides[i] << std::endl;
+//             std::cout << "sides["<<i+1 << "]=" <<  sides[i+1] << std::endl;
+//             std::cout << "sides["<<i+2 << "]=" <<  sides[i+2] << std::endl;
+//             std::cout << "sides["<<i+3 << "]=" <<  sides[i+3] << std::endl;
+
+               
side_inputs.push_back(Matrix<double>{reinterpret_cast<double*>(sides[ptr_idx+3]),
 reinterpret_cast<uint32_t*>(sides[ptr_idx+1]),
+                                                                       
reinterpret_cast<uint32_t*>(sides[ptr_idx+2]), m, n,
+                                                                       
static_cast<uint32_t>(sides[ptr_idx])});
+               
+               uint32_t* row_ptr = 
reinterpret_cast<uint32_t*>(sides[ptr_idx+1]);
+               uint32_t* col_idxs = 
reinterpret_cast<uint32_t*>(sides[ptr_idx+2]);
+               double* data_ptr = reinterpret_cast<double*>(sides[ptr_idx+3]);
+#ifdef __DEBUG
+               if(row_ptr != nullptr) {
+                       for (auto i = 0; i < 2/*m*/; ++i) {
+                               uint32_t alen = row_ptr[i+1] - row_ptr[i];
+                               std::cout << "row_start=" << row_ptr[i] << " 
row_end=" << row_ptr[i+1] << " alen[" << i << "]=" << alen << std::endl << " 
col_idxs:\n";
+                               for (auto j = 0; j < alen; ++j) {
+                                       std::cout << " " << *col_idxs;
+                               }
+                               std::cout << "\ndata:" << std::endl;
+                               for (auto j = 0; j < alen; ++j) {
+                                       std::cout << " " << *data_ptr;
+                               }
+                               std::cout << std::endl;
+                       }
+               }
+               std::cout << "side input #" << input_idx << " m=" << m << " n=" 
<< n << std::endl;
+#endif
+       }
+
+       std::unique_ptr<Matrix<double>> out;
+       if(out_obj != nullptr) {
+//             std::cout << "out not null" << std::endl;
+               out = 
std::make_unique<Matrix<double>>(Matrix<double>{reinterpret_cast<double*>(output[3]),
+                                                                               
                                reinterpret_cast<uint32_t*>(output[1]),
+                                                                               
                                        reinterpret_cast<uint32_t*>(output[2]),
+                                                                               
                                          
static_cast<uint32_t>(env->CallIntMethod(out_obj, mat_obj_num_rows)),
+                                                                               
                                          
static_cast<uint32_t>(env->CallIntMethod(out_obj, mat_obj_num_cols)),
+                                                                               
                                          static_cast<uint32_t>(output[0])});
+       }
+
+       double result = ctx_->execute_kernel(cstr_name, in, side_inputs, 
out.get(), scalars,
+                                                                         
env->GetArrayLength(scalars_), grix);
+
+       RELEASE_ARRAY(env, in_ptrs, inputs);
+       RELEASE_ARRAY(env, side_ptrs, sides);
+       RELEASE_ARRAY(env, out_ptrs, output);
+       RELEASE_ARRAY(env, scalars_, scalars);
+
+       // FIXME: that release causes an error
+       //std::cout << "releasing " << name_ << std::endl;
+       env->ReleaseStringUTFChars(name, cstr_name);
+       return result;
 }
 
-JNIEXPORT jfloat JNICALL
-Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1f(
-    JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs,
-    jlongArray side_ptrs, jlong out_ptr, jfloatArray scalars_, jlong m, jlong 
n, jlong grix) {
-
-  SpoofCUDAContext *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx);
-
-  const char *cstr_name = env->GetStringUTFChars(name, NULL);
-
-  float **inputs = reinterpret_cast<float**>(GET_ARRAY(env, in_ptrs));
-  float **sides = reinterpret_cast<float **>(GET_ARRAY(env, side_ptrs));
-  float *scalars = reinterpret_cast<float *>(GET_ARRAY(env, scalars_));
-
-  float result = ctx_->execute_kernel(
-      cstr_name, inputs, env->GetArrayLength(in_ptrs), sides, 
env->GetArrayLength(side_ptrs),
-      reinterpret_cast<float *>(out_ptr), scalars, 
env->GetArrayLength(scalars_), m, n, grix);
-
-  RELEASE_ARRAY(env, in_ptrs, inputs);
-  RELEASE_ARRAY(env, side_ptrs, sides);
-  RELEASE_ARRAY(env, scalars_, scalars);
-
-  // FIXME: that release causes an error
-  env->ReleaseStringUTFChars(name, cstr_name);
-  return result;
-}
+//JNIEXPORT jfloat JNICALL
+//Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1f(
+//    JNIEnv *env, jobject jobj, jlong ctx, jstring name, jlongArray in_ptrs,
+//    jlongArray side_ptrs, jlong out_ptr, jfloatArray scalars_, jlong m, 
jlong n, jlong out_len, jlong grix) {
+//
+//  SpoofCUDAContext *ctx_ = reinterpret_cast<SpoofCUDAContext *>(ctx);
+//
+//  const char *cstr_name = env->GetStringUTFChars(name, NULL);
+//
+//  float **inputs = reinterpret_cast<float**>(GET_ARRAY(env, in_ptrs));
+//  float **sides = reinterpret_cast<float **>(GET_ARRAY(env, side_ptrs));
+//  float *scalars = reinterpret_cast<float *>(GET_ARRAY(env, scalars_));
+//
+//  float result = ctx_->execute_kernel(
+//      cstr_name, inputs, env->GetArrayLength(in_ptrs), sides, 
env->GetArrayLength(side_ptrs),
+//      reinterpret_cast<float *>(out_ptr), scalars, 
env->GetArrayLength(scalars_), m, n, out_len, grix);
+//
+//  RELEASE_ARRAY(env, in_ptrs, inputs);
+//  RELEASE_ARRAY(env, side_ptrs, sides);
+//  RELEASE_ARRAY(env, scalars_, scalars);
+//
+//  // FIXME: that release causes an error
+//  env->ReleaseStringUTFChars(name, cstr_name);
+//  return result;
+//}
diff --git a/src/main/cuda/spoof-launcher/jni_bridge.h 
b/src/main/cuda/spoof-launcher/jni_bridge.h
index a06bb1b..788f93e 100644
--- a/src/main/cuda/spoof-launcher/jni_bridge.h
+++ b/src/main/cuda/spoof-launcher/jni_bridge.h
@@ -63,16 +63,16 @@ 
Java_org_apache_sysds_hops_codegen_SpoofCompiler_compile_1cuda_1kernel(
  */
 JNIEXPORT jdouble JNICALL
 Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1d(
-    JNIEnv *, jobject, jlong, jstring, jlongArray, jlongArray, jlong, 
jdoubleArray, jlong, jlong, jlong);
+    JNIEnv *, jobject, jlong, jstring, jlongArray, jint, jlongArray, 
jlongArray, jdoubleArray, jlong, jobject, jobject);
 
-/*
- * Class:     org_apache_sysds_runtime_instructions_gpu_SpoofCUDAInstruction
- * Method:    execute_f
- * Signature: (...)Z
- */
-JNIEXPORT jfloat JNICALL
-Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1f(
-    JNIEnv *, jobject, jlong, jstring, jlongArray, jlongArray, jlong, 
jfloatArray, jlong, jlong, jlong);
+///*
+// * Class:     org_apache_sysds_runtime_instructions_gpu_SpoofCUDAInstruction
+// * Method:    execute_f
+// * Signature: (...)Z
+// */
+//JNIEXPORT jfloat JNICALL
+//Java_org_apache_sysds_runtime_codegen_SpoofCUDA_execute_1f(
+//    JNIEnv *, jobject, jlong, jstring, jlongArray, jlongArray, jlong, 
jfloatArray, jlong, jlong, jlong, jlong);
 
 #ifdef __cplusplus
 }
diff --git a/src/main/cuda/spoof/cellwise.cu b/src/main/cuda/spoof/cellwise.cu
index 2f76802..f26161b 100644
--- a/src/main/cuda/spoof/cellwise.cu
+++ b/src/main/cuda/spoof/cellwise.cu
@@ -1,5 +1,4 @@
-%TMP%
-
+/*%TMP%*/SPOOF_OP_NAME
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -28,27 +27,56 @@
 #include "reduction.cuh"
 #include "spoof_utils.cuh"
 #include "utils.cuh"
+#include "Matrix.h"
 
-template<typename T>
+template<typename T, int NUM_B>
 struct SpoofCellwiseOp {
-   T**b; T* scalars; 
-   int m, n, grix_;
+               MatrixAccessor<T> A;
+               MatrixAccessor<T> b[NUM_B];
+               MatrixAccessor<T> c;
+               T* scalars;
+               T* avals;
+               uint32_t* aix;
+               uint32_t alen;
+               uint32_t& n;
+               uint32_t _grix;
+
+       SpoofCellwiseOp(Matrix<T>* _A, Matrix<T>* _B, Matrix<T>* _C, T* 
scalars, uint32_t grix) :
+                       n(_A->cols), scalars(scalars), _grix(grix) {
+               A.init(_A);
+               c.init(_C);
+               alen = A.row_len(grix);
+
+               if(_B)
+                       for(auto i = 0; i < NUM_B; ++i)
+                               b[i].init(&(_B[i]));
+       }
 
-   SpoofCellwiseOp(T** b, T* scalars, int m, int n, int grix) : 
-       b(b), scalars(scalars), m(m), n(n), grix_(grix) {}
+       __device__  __forceinline__ T operator()(T a, uint32_t idx, uint32_t 
rix, uint32_t cix) {
+//%NEED_RIX%
+//%NEED_CIX%
+//%NEED_GRIX%
 
-   __device__  __forceinline__ T operator()(T a, int idx) const {
-        int rix = idx / n;
-        int cix = idx % n;
-        int grix = grix_ + rix;
 %BODY_dense%
-        return %OUT%;
-   }
+               return %OUT%;
+       }
 };
 
-template<typename T>
-__global__ void %TMP% (T *a, T** b, T* c, T* scalars, int m, int n, int grix) {
-   %AGG_OP%<T> agg_op;
-   SpoofCellwiseOp<T> spoof_op(b, scalars, m, n, grix);
-   %TYPE%<T, %AGG_OP%<T>, SpoofCellwiseOp<T>>(a, c, m, n, %INITIAL_VALUE%, 
agg_op, spoof_op);
+template<typename T, int NUM_B>
+__global__ void /*%TMP%*/SPOOF_OP_NAME_DENSE (Matrix<T>* a, Matrix<T>* b, 
Matrix<T>* c, T* scalars, uint32_t n, uint32_t grix) {
+       %AGG_OP%<T> agg_op;
+       SpoofCellwiseOp<T, NUM_B> spoof_op(a, b, c, scalars, grix);
+       %TYPE%<T, %AGG_OP%<T>, SpoofCellwiseOp<T, NUM_B>>(&(spoof_op.A), 
&(spoof_op.c), n, %INITIAL_VALUE%, agg_op, spoof_op);
 };
+
+template<typename T, int NUM_B>
+__global__ void /*%TMP%*/SPOOF_OP_NAME_SPARSE (Matrix<T>* a, Matrix<T>* b, 
Matrix<T>* c, T* scalars, uint32_t n, uint32_t grix) {
+       %AGG_OP%<T> agg_op;
+       SpoofCellwiseOp<T, NUM_B> spoof_op(a, b, c, scalars, grix);
+       %TYPE%_SPARSE<T, %AGG_OP%<T>, SpoofCellwiseOp<T, NUM_B>>(&(spoof_op.A), 
&(spoof_op.c), n, %INITIAL_VALUE%, agg_op, spoof_op);
+
+//     if(blockIdx.x == 0 && threadIdx.x == 0) {
+//             for(auto i = 0; i < 30; ++i)
+//                     printf("%4.3f ", spoof_op.c.val(i));
+//     }
+};
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
index bbe5454..80192b5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNode.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.codegen.cplan;
 
 import java.util.ArrayList;
 
+import org.apache.sysds.hops.codegen.SpoofCompiler;
 import org.apache.sysds.hops.codegen.template.TemplateUtils;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
@@ -84,13 +85,23 @@ public abstract class CNode
 
        public String getVarname(GeneratorAPI api) { return getVarname(); }
 
-       public String getVectorLength() {
-               if( getVarname().startsWith("a") )
-                       return "len";
-               else if( getVarname().startsWith("b") )
-                       return getVarname()+".clen";
-               else if( _dataType==DataType.MATRIX )
-                       return getVarname()+".length";
+       public String getVectorLength(GeneratorAPI api) {
+               if(api == GeneratorAPI.CUDA) {
+                       if( getVarname().startsWith("a") )
+                               return "a.cols()";
+                       if(getVarname().startsWith("b"))
+                               return getVarname()+".cols()";
+                       else                            
+                               return getVarname()+".length";
+               }
+               else {
+                       if( getVarname().startsWith("a") )
+                               return "len";
+                       if(getVarname().startsWith("b"))
+                               return getVarname() + ".clen";
+                       else if(_dataType == DataType.MATRIX)
+                               return getVarname() + ".length";
+               }
                return "";
        }
        
@@ -210,13 +221,17 @@ public abstract class CNode
                        && _literal == cthat._literal;
        }
        
-       protected String replaceUnaryPlaceholders(String tmp, String varj, 
boolean vectIn) {
+       protected String replaceUnaryPlaceholders(String tmp, String varj, 
boolean vectIn, GeneratorAPI api) {
                //replace sparse and dense inputs
                tmp = tmp.replace("%IN1v%", varj+"vals");
                tmp = tmp.replace("%IN1i%", varj+"ix");
                tmp = tmp.replace("%IN1%", 
-                       (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? 
varj + ".values(rix)" :
-                       (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? 
varj + ".values(0)" : varj));
+                       (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? 
+                               ((api == GeneratorAPI.JAVA) ? varj + 
".values(rix)" : varj + ".vals(0)" ) :
+                               (vectIn && 
TemplateUtils.isRowVector(_inputs.get(0)) ? 
+                                       ((api == GeneratorAPI.JAVA) ? varj + 
".values(0)" : varj + ".val(0)") :
+                                               (varj.startsWith("a") || 
TemplateUtils.isMatrix(_inputs.get(0))) ?
+                                                               (api == 
GeneratorAPI.JAVA ? varj : varj + ".vals(0)") : varj));
                
                //replace start position of main input
                String spos = (_inputs.get(0) instanceof CNodeData 
@@ -228,7 +243,7 @@ public abstract class CNode
                
                //replace length
                if( _inputs.get(0).getDataType().isMatrix() )
-                       tmp = tmp.replace("%LEN%", 
_inputs.get(0).getVectorLength());
+                       tmp = tmp.replace("%LEN%", 
_inputs.get(0).getVectorLength(api));
                
                return tmp;
        }
@@ -236,33 +251,45 @@ public abstract class CNode
        protected CodeTemplate getLanguageTemplateClass(CNode caller, 
GeneratorAPI api) {
                switch (api) {
                        case CUDA:
-                               if(caller instanceof CNodeCell)
-                                       return new 
org.apache.sysds.hops.codegen.cplan.cuda.CellWise();
-                               else if (caller instanceof CNodeUnary)
-                                       return new 
org.apache.sysds.hops.codegen.cplan.cuda.Unary();
-                               else if (caller instanceof CNodeBinary)
+                               if(caller instanceof CNodeBinary)
                                        return new 
org.apache.sysds.hops.codegen.cplan.cuda.Binary();
-                               else if (caller instanceof CNodeTernary)
+                               else if(caller instanceof CNodeTernary)
                                        return new 
org.apache.sysds.hops.codegen.cplan.cuda.Ternary();
-                               else
-                                       return null;
-                       case JAVA:
-                               if(caller instanceof CNodeCell)
-                                       return new 
org.apache.sysds.hops.codegen.cplan.java.CellWise();
-                               else if (caller instanceof CNodeUnary)
-                                       return new 
org.apache.sysds.hops.codegen.cplan.java.Unary();
-                               else if (caller instanceof CNodeBinary)
+                               else if(caller instanceof CNodeUnary)
+                                       return new 
org.apache.sysds.hops.codegen.cplan.cuda.Unary();
+                               else return null;
+                       case JAVA: 
+                               if(caller instanceof CNodeBinary)
                                        return new 
org.apache.sysds.hops.codegen.cplan.java.Binary();
-                               else if (caller instanceof CNodeTernary)
+                               else if(caller instanceof CNodeTernary)
                                        return new 
org.apache.sysds.hops.codegen.cplan.java.Ternary();
-
-                               else
-                                       return null;
+                               else if(caller instanceof CNodeUnary)
+                                       return new 
org.apache.sysds.hops.codegen.cplan.java.Unary();
+                               else return null;
                        default:
                                throw new RuntimeException("API not supported 
by code generator: " + api.toString());
                }
        }
-
+       
+       protected String getLanguageTemplate(CNode caller, GeneratorAPI api) {
+               switch (api) {
+                       case CUDA:
+                               if(caller instanceof CNodeCell)
+                                       return 
CodeTemplate.getTemplate("/cuda/spoof/cellwise.cu");
+                               else if(caller instanceof CNodeRow)
+                                       return 
CodeTemplate.getTemplate("/cuda/spoof/rowwise.cu");
+                               else return null;
+                       case JAVA:
+                               if(caller instanceof CNodeCell)
+                                       return 
CodeTemplate.getTemplate("/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template");
+                               else if(caller instanceof CNodeRow)
+                                       return 
CodeTemplate.getTemplate("/java/org/apache/sysds/hops/codegen/cplan/java/Rowwise.java.template");
+                               else return null;
+                       default:
+                               throw new RuntimeException("API not supported 
by code generator: " + api.toString());
+               }
+       }
+       
        public abstract boolean isSupported(GeneratorAPI api);
        
        public void setVarName(String name) { _genVar = name; }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 15a26bc..925e055 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -113,6 +113,9 @@ public class CNodeBinary extends CNode {
                        String [] tmp = this.name().split("_");
                        return StringUtils.capitalize(tmp[1].toLowerCase());
                }
+               
+               public boolean isNotSupportedBySpoofCUDA() {
+                       return this == VECT_BIASADD || this == VECT_BIASMULT;}
        }
        
        private final BinType _type;
@@ -157,7 +160,6 @@ public class CNodeBinary extends CNode {
                boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
                        && _inputs.get(1).getDataType().isMatrix());
                String var = createVarname();
-//             String tmp = _type.getTemplate(api, lang, lsparseLhs, 
lsparseRhs, scalarVector, scalarInput);
                String tmp = getLanguageTemplateClass(this, 
api).getTemplate(_type, lsparseLhs, lsparseRhs, scalarVector, scalarInput);
 
                tmp = tmp.replace("%TMP%", var);
@@ -169,24 +171,29 @@ public class CNodeBinary extends CNode {
                        //replace sparse and dense inputs
                        tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals");
                        tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix");
-                       tmp = tmp.replace("%IN"+(j+1)+"%", 
-                               varj.startsWith("b") ? varj + ".values(rix)" : 
varj );
+                       tmp = tmp.replace("%IN"+(j+1)+"%",
+                                       varj.startsWith("a") ? (api == 
GeneratorAPI.JAVA ? varj : 
+                                               (_inputs.get(j).getDataType() 
== DataType.MATRIX ? varj + ".vals(0)" : varj)) :
+//                                     varj.startsWith("b") ? (api == 
GeneratorAPI.JAVA ? varj + ".values(rix)" : varj + ".vals(0)") : varj);
+                                               varj.startsWith("b") ? (api == 
GeneratorAPI.JAVA ? varj + ".values(rix)" : 
+                                                               (_type == 
BinType.VECT_MATRIXMULT ? varj : varj + ".vals(0)")) :
+                                                       
_inputs.get(j).getDataType() == DataType.MATRIX ? (api == GeneratorAPI.JAVA ? 
varj : varj + ".vals(0)") : varj);
                        
                        //replace start position of main input
                        tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) 
instanceof CNodeData 
-                               && _inputs.get(j).getDataType().isMatrix()) ? 
(!varj.startsWith("b")) ? varj+"i" : 
-                               (TemplateUtils.isMatrix(_inputs.get(j)) && 
_type!=BinType.VECT_MATRIXMULT) ? 
-                               varj + ".pos(rix)" : "0" : "0");
+                                       && 
_inputs.get(j).getDataType().isMatrix()) ? (!varj.startsWith("b")) ? varj+"i" : 
+                                       (TemplateUtils.isMatrix(_inputs.get(j)) 
&& _type!=BinType.VECT_MATRIXMULT) ? 
+                                       varj + ".pos(rix)" : "0" : "0");
                }
                //replace length information (e.g., after matrix mult)
                if( _type == BinType.VECT_OUTERMULT_ADD ) {
                        for( int j=0; j<2; j++ )
-                               tmp = tmp.replace("%LEN"+(j+1)+"%", 
_inputs.get(j).getVectorLength());
+                               tmp = tmp.replace("%LEN"+(j+1)+"%", 
_inputs.get(j).getVectorLength(api));
                }
                else { //general case 
                        CNode mInput = getIntermediateInputVector();
                        if( mInput != null )
-                               tmp = tmp.replace("%LEN%", 
mInput.getVectorLength());
+                               tmp = tmp.replace("%LEN%", 
mInput.getVectorLength(api));
                }
                
                sb.append(tmp);
@@ -418,6 +425,11 @@ public class CNodeBinary extends CNode {
        @Override
        public boolean isSupported(GeneratorAPI api) {
                boolean is_supported = (api == GeneratorAPI.CUDA || api == 
GeneratorAPI.JAVA);
+               
+               // ToDo: support these
+               if(api == GeneratorAPI.CUDA)
+                       is_supported = !_type.isNotSupportedBySpoofCUDA();
+               
                int i = 0;
                while(is_supported && i < _inputs.size()) {
                        CNode in = _inputs.get(i++);
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 74292b6..f2f0179 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -119,16 +119,42 @@ public class CNodeCell extends CNodeTpl
        public String codegen(boolean sparse, GeneratorAPI _api) {
                api = _api;
 
-               String tmp = getLanguageTemplateClass(this, 
api).getTemplate(_type);
+               String tmp = getLanguageTemplate(this, api);
 
                //generate dense/sparse bodies
                String tmpDense = _output.codegen(false, api);
+               // ToDo: workaround to fix name clash of cell and row template
+               if(api == GeneratorAPI.CUDA)
+                       tmpDense = tmpDense.replace("a.vals(0)", "a");
                _output.resetGenerated();
                
+               String varName; 
                if(getVarname() == null)
-                       tmp = tmp.replace("%TMP%", createVarname());
+//                     tmp = tmp.replace("%TMP%", createVarname());
+                       varName = createVarname();
                else
-                       tmp = tmp.replace("%TMP%", getVarname());
+//                     tmp = tmp.replace("%TMP%", getVarname());
+                       varName = getVarname();
+               
+               if(api == GeneratorAPI.JAVA)
+                       tmp = tmp.replace("%TMP%", varName);
+               else
+                       tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
+               
+               if(tmpDense.contains("grix"))
+                       tmp = tmp.replace("//%NEED_GRIX%", "\t\tuint32_t 
grix=_grix + rix;");
+               else
+                       tmp = tmp.replace("//%NEED_GRIX%", "");
+               
+//             if(tmpDense.contains("rix"))
+//                     tmp = tmp.replace("//%NEED_RIX%", "\t\tuint32_t rix = 
idx / A.cols();\n");
+//             else
+                       tmp = tmp.replace("//%NEED_RIX%", "");
+               
+//             if(tmpDense.contains("cix"))
+//                     tmp = tmp.replace("//%NEED_CIX%", "\t\tuint32_t cix = 
idx % A.cols();");
+//             else
+                       tmp = tmp.replace("//%NEED_CIX%", "");
                
                tmp = tmp.replace("%BODY_dense%", tmpDense);
                
@@ -140,7 +166,10 @@ public class CNodeCell extends CNodeTpl
                tmp = tmp.replace("%AGG_OP_NAME%", (_aggOp != null) ? "AggOp." 
+ _aggOp.name() : "null");
                tmp = tmp.replace("%SPARSE_SAFE%", 
String.valueOf(isSparseSafe()));
                tmp = tmp.replace("%SEQ%", String.valueOf(containsSeq()));
-
+               
+               // maybe empty lines
+               //tmp = tmp.replaceAll("(?m)^[ \t]*\r?\n", "");
+               
                if(api == GeneratorAPI.CUDA) {
                        // ToDo: initial_value is misused to pass VT (values 
per thread) to no_agg operator
                        String agg_op = "IdentityOp";
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
index b91c66f..940f3e2 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
@@ -23,6 +23,7 @@ import org.apache.commons.lang.StringUtils;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.hops.codegen.SpoofCompiler;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
@@ -92,7 +93,7 @@ public class CNodeData extends CNode
                                return isSinglePrecision() ? "-CUDART_INF_F" : 
"-CUDART_INF";
                        else if ("true".equals(_name) || "false".equals(_name))
                                return "true".equals(_name) ? "1" : "0";
-                       else if (StringUtils.isNumeric(_name))
+                       else if (CodegenUtils.isNumeric(_name))
                                return isSinglePrecision() ? _name + ".0f" : 
_name + ".0";
                        else
                                return _name;
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java
index 5500ddb..dcf18ec 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeNary.java
@@ -115,7 +115,7 @@ public class CNodeNary extends CNode
        public String codegen(boolean sparse, GeneratorAPI api) {
                if( isGenerated() )
                        return "";
-               
+                               
                StringBuilder sb = new StringBuilder();
                
                //generate children
@@ -134,8 +134,8 @@ public class CNodeNary extends CNode
                String varj1 = _inputs.get(0).getVarname();
                String varj2 = _inputs.get(1).getVarname();
                tmp = (_type == NaryType.VECT_CONV2DMM) ?
-                       replaceBinaryPlaceholders(tmp, new 
String[]{varj1,varj2}, false) :
-                       replaceUnaryPlaceholders(tmp, varj1, false);
+                       replaceBinaryPlaceholders(tmp, new 
String[]{varj1,varj2}, false, api) :
+                       replaceUnaryPlaceholders(tmp, varj1, false, api);
                
                sb.append(tmp);
                
@@ -251,7 +251,7 @@ public class CNodeNary extends CNode
        }
        
 
-       private String replaceBinaryPlaceholders(String tmp, String[] vars, 
boolean vectIn) {
+       private String replaceBinaryPlaceholders(String tmp, String[] vars, 
boolean vectIn, GeneratorAPI api) {
                //replace sparse and dense inputs
                for( int j=0; j<2; j++ ) {
                        String varj = vars[j];
@@ -259,8 +259,11 @@ public class CNodeNary extends CNode
                        //replace sparse and dense inputs
                        tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals");
                        tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix");
-                       tmp = tmp.replace("%IN"+(j+1)+"%", 
-                               varj.startsWith("b") ? varj + ".values(rix)" : 
varj );
+//                     tmp = tmp.replace("%IN"+(j+1)+"%", 
+//                             varj.startsWith("b") ? varj + ".values(rix)" : 
varj );
+                       tmp = tmp.replace("%IN"+(j+1)+"%",
+                               varj.startsWith("b") ? ((api == 
GeneratorAPI.JAVA) ? varj + ".values(rix)" :
+                                       varj + ".vals(0)") : varj);
                        
                        //replace start position of main input
                        tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) 
instanceof CNodeData 
@@ -270,7 +273,7 @@ public class CNodeNary extends CNode
                
                //replace length
                if( _inputs.get(0).getDataType().isMatrix() )
-                       tmp = tmp.replace("%LEN%", 
_inputs.get(0).getVectorLength());
+                       tmp = tmp.replace("%LEN%", 
_inputs.get(0).getVectorLength(api));
                
                return tmp;
        }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
index 2ff9054..6ec252a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
@@ -76,6 +76,11 @@ public class CNodeUnary extends CNode
                                POW2, MULT2, ABS, ROUND, CEIL, FLOOR, SIGN, 
                                SIN, TAN, SPROP}, this);
                }
+               
+               public boolean isNotSupportedBySpoofCUDA() {
+                       return this == VECT_CUMSUM || this == VECT_CUMMIN || 
this == VECT_CUMMAX|| 
+                                       this == VECT_SPROP || this == 
VECT_SIGMOID;
+               }
        }
        
        private UnaryType _type;
@@ -110,12 +115,12 @@ public class CNodeUnary extends CNode
                        && !_inputs.get(0).isLiteral());
                String var = createVarname();
                String tmp = getLanguageTemplateClass(this, 
api).getTemplate(_type, lsparse);
-               tmp = tmp.replace("%TMP%", var);
+               tmp = tmp.replaceAll("%TMP%", var);
                
                //replace sparse and dense inputs
                String varj = _inputs.get(0).getVarname();
                boolean vectIn = varj.startsWith("b") && 
!_type.isScalarLookup();
-               tmp = replaceUnaryPlaceholders(tmp, varj, vectIn);
+               tmp = replaceUnaryPlaceholders(tmp, varj, vectIn, api);
                
                sb.append(tmp);
                
@@ -260,9 +265,15 @@ public class CNodeUnary extends CNode
                return super.equals(that)
                        && _type == that._type;
        }
+                       
        @Override
        public boolean isSupported(GeneratorAPI api) {
                boolean is_supported = (api == GeneratorAPI.CUDA || api == 
GeneratorAPI.JAVA);
+               
+               // ToDo: support these
+               if(api == GeneratorAPI.CUDA)
+                       is_supported = !_type.isNotSupportedBySpoofCUDA();
+               
                int i = 0;
                while(is_supported && i < _inputs.size()) {
                        CNode in = _inputs.get(i++);
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
index 8a8a3be..34f0b66 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
@@ -19,18 +19,52 @@
 
 package org.apache.sysds.hops.codegen.cplan;
 
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.io.IOUtilFunctions;
 
-public interface CodeTemplate {
+import java.io.FileInputStream;
+import java.io.IOException;
 
-    String getTemplate();
-
-    String getTemplate(CNodeUnary.UnaryType type, boolean sparse);
-
-    String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean 
sparseRhs, boolean scalarVector,
-                              boolean scalarInput);
-
-    String getTemplate(CNodeTernary.TernaryType type, boolean sparse);
-
-    String getTemplate(SpoofCellwise.CellType ct);
+public abstract class CodeTemplate {
+       
+       public String getTemplate() {
+               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
+       }
+       
+       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
+               boolean scalarInput) {
+               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
+       }
+       
+       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
+               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
+       }
+       
+       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
+               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
+       }
+       
+       public static String getTemplate(String templateFileName) {
+               try {
+                       // Change prefix to the code template file if running 
from jar. File were extracted to a temporary
+                       // directory in that case. By default we load the 
template from the source tree.
+                       
if(CodeTemplate.class.getProtectionDomain().getCodeSource().getLocation().getPath().contains(".jar"))
 {
+                               if(templateFileName.contains(".java")) {
+                                       templateFileName = templateFileName
+                                               
.replace("/java/org/apache/sysds/hops/codegen/cplan/java/", "/java/spoof/");
+                               }
+                               return (IOUtilFunctions.toString(new 
FileInputStream(ConfigurationManager.getDMLConfig()
+                                       .getTextValue(DMLConfig.LOCAL_TMP_DIR) 
+ templateFileName)));
+                       }
+                       else
+                               return IOUtilFunctions.toString(new 
FileInputStream(System.getProperty("user.dir") +
+                                       "/src/main" + templateFileName));
+               }
+               catch(IOException e) {
+                       System.out.println(e.getMessage());
+                       return null;
+               }
+       }
+       
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
index 0365afd..30b48da 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java
@@ -20,29 +20,13 @@
 package org.apache.sysds.hops.codegen.cplan.cuda;
 
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
 
-public class Binary implements CodeTemplate {
+public class Binary extends CodeTemplate {
+       
        @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
        public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
                                                          boolean scalarInput) {
 
@@ -263,11 +247,11 @@ public class Binary implements CodeTemplate {
 
                                //scalar-scalar operations
                                case MULT:
-                                       return "        T %TMP% = %IN1% * 
%IN2%;\n";
+                                       return "                T %TMP% = %IN1% 
* %IN2%;\n";
                                case DIV:
                                        return "        T %TMP% = %IN1% / 
%IN2%;\n";
                                case PLUS:
-                                       return "        T %TMP% = %IN1% + 
%IN2%;\n";
+                                       return "                T %TMP% = %IN1% 
+ %IN2%;\n";
                                case MINUS:
                                        return "        T %TMP% = %IN1% - 
%IN2%;\n";
                                case MODULUS:
@@ -307,16 +291,11 @@ public class Binary implements CodeTemplate {
                                case BITWAND:
                                        return "        T %TMP% = bwAnd(%IN1%, 
%IN2%);\n";
                                case SEQ_RIX:
-                                       return "        T %TMP% = %IN1% + grix 
* %IN2%;\n"; //0-based global rix
+                                       return "                T %TMP% = %IN1% 
+ grix * %IN2%;\n"; //0-based global rix
 
                                default:
                                        throw new RuntimeException("Invalid 
binary type: " + this.toString());
                        }
                }
        }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/CellWise.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/CellWise.java
deleted file mode 100644
index beb2398..0000000
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/CellWise.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.codegen.cplan.cuda;
-
-import java.io.FileInputStream;
-import java.io.IOException;
-
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
-import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
-import org.apache.sysds.runtime.io.IOUtilFunctions;
-
-
-// ToDo: clean code template and load from file
-public class CellWise implements CodeTemplate {
-
-       private static final String TEMPLATE_PATH = "/cuda/spoof/cellwise.cu";
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               try {
-                       // Change prefix to the code template file if running 
from jar. File were extracted to a temporary
-                       // directory in that case. By default we load the 
template from the source tree.
-                       
if(CellWise.class.getProtectionDomain().getCodeSource().getLocation().getPath().contains(".jar"))
-                               return(IOUtilFunctions.toString(new 
FileInputStream(ConfigurationManager.getDMLConfig()
-                                               
.getTextValue(DMLConfig.LOCAL_TMP_DIR) + TEMPLATE_PATH)));
-                       else
-                               return IOUtilFunctions.toString(new 
FileInputStream(System.getProperty("user.dir") +
-                                               "/src/main" + TEMPLATE_PATH));
-               }
-               catch(IOException e) {
-                       System.out.println(e.getMessage());
-                       return null;
-               }
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-}
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
index 355e579..dd06d6c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Ternary.java
@@ -19,15 +19,12 @@
 
 package org.apache.sysds.hops.codegen.cplan.cuda;
 
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
 
-public class Ternary implements CodeTemplate {
+public class Ternary extends CodeTemplate {
 
        @Override
        public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
@@ -58,10 +55,11 @@ public class Ternary implements CodeTemplate {
                                case LOOKUP_RC1:
                                        return sparse ?
                                                        "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
-                                                       "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
+//                                                     "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
+                                                       "               T %TMP% 
= %IN1%.val(rix, %IN3%-1);\n";
 
                                case LOOKUP_RVECT1:
-                                       return "        T[] %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
+                                       return "\t\tVector<T>& %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
 
                                default:
                                        throw new RuntimeException("Invalid 
ternary type: " + this.toString());
@@ -94,10 +92,12 @@ public class Ternary implements CodeTemplate {
                                case LOOKUP_RC1:
                                        return sparse ?
                                                        "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
-                                                       "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
-
+//                                                     "       T %TMP% = 
getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
+                                                       "               T %TMP% 
= %IN1%.val(rix, %IN3%-1);\n";
+                               
+                               
                                case LOOKUP_RVECT1:
-                                       return "        T[] %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
+                                       return "\t\tVector<T>& %TMP% = 
getVector(%IN1%, %IN2%, rix, %IN3%-1, this);\n";
 
                                default:
                                        throw new RuntimeException("Invalid 
ternary type: "+this.toString());
@@ -105,26 +105,4 @@ public class Ternary implements CodeTemplate {
 
                }
        }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
-                                                         boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
index 0b5852d..459d1c8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Unary.java
@@ -20,15 +20,13 @@
 package org.apache.sysds.hops.codegen.cplan.cuda;
 
 import org.apache.commons.lang.StringUtils;
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
 import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
 
-public class Unary implements CodeTemplate {
+public class Unary extends CodeTemplate {
+
        @Override
        public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
                if(isSinglePrecision()) {
@@ -179,7 +177,8 @@ public class Unary implements CodeTemplate {
                                case LOOKUP_R:
                                        return sparse ?
                                                "       T %TMP% = 
getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
-                                               "       T %TMP% = 
getValue(%IN1%, rix);\n";
+                                               "               T %TMP% = 
%IN1%.val(rix);\n";
+//                                             "       T %TMP% = 
getValue(%IN1%, rix);\n";
                                case LOOKUP_C:
                                        return "        T %TMP% = 
getValue(%IN1%, n, 0, cix);\n";
                                case LOOKUP_RC:
@@ -211,11 +210,11 @@ public class Unary implements CodeTemplate {
                                case TANH:
                                        return "        T %TMP% = 
tanh(%IN1%);\n";
                                case SIGN:
-                                       return "        T %TMP% = 
signbit(%IN1%) == 0 ? 1.0f : -1.0f;\n";
+                                       return "        T %TMP% = 
signbit(%IN1%) == 0 ? 1.0 : -1.0;\n";
                                case SQRT:
                                        return "        T %TMP% = 
sqrt(%IN1%);\n";
                                case LOG:
-                                       return "        T %TMP% = 
log(%IN1%);\n";
+                                       return "                T %TMP% = 
log(%IN1%);\n";
                                case ROUND:
                                        return "        T %TMP% = 
round(%IN1%);\n";
                                case CEIL:
@@ -235,24 +234,4 @@ public class Unary implements CodeTemplate {
 
                }
        }
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index 28b970a..1453b44 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -20,13 +20,10 @@
 package org.apache.sysds.hops.codegen.cplan.java;
 
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
-public class Binary implements CodeTemplate {
-       @Override
+public class Binary extends CodeTemplate {
+
        public String getTemplate(BinType type, boolean sparseLhs, boolean 
sparseRhs, boolean scalarVector,
                                                          boolean scalarInput) {
 
@@ -178,23 +175,4 @@ public class Binary implements CodeTemplate {
                }
        }
 
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java
deleted file mode 100644
index 85476a7..0000000
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/CellWise.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.codegen.cplan.java;
-
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
-import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
-
-public class CellWise implements CodeTemplate {
-       public static final String TEMPLATE =
-                       "package codegen;\n"
-                       + "import 
org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
-                       + "import 
org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
-                       + "import 
org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;\n"
-                       + "import 
org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;\n"
-                       + "import 
org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\n"
-                       + "import org.apache.commons.math3.util.FastMath;\n"
-                       + "\n"
-                       + "public final class %TMP% extends SpoofCellwise {\n"
-                       + "  public %TMP%() {\n"
-                       + "    super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, 
%AGG_OP_NAME%);\n"
-                       + "  }\n"
-                       + "  protected double genexec(double a, SideInput[] b, 
double[] scalars, int m, int n, long grix, int rix, int cix) { \n"
-                       + "%BODY_dense%"
-                       + "    return %OUT%;\n"
-                       + "  }\n"
-                       + "}\n";
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               switch(ct) {
-                       case NO_AGG:
-                       case FULL_AGG:
-                       case ROW_AGG:
-                       case COL_AGG:
-                       default:
-                               return TEMPLATE;
-               }
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-}
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
similarity index 57%
copy from src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
copy to 
src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
index 8a8a3be..84183ae 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CodeTemplate.java
+++ 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Cellwise.java.template
@@ -17,20 +17,19 @@
  * under the License.
  */
 
-package org.apache.sysds.hops.codegen.cplan;
-
+package codegen;
+import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
-
-public interface CodeTemplate {
-
-    String getTemplate();
-
-    String getTemplate(CNodeUnary.UnaryType type, boolean sparse);
-
-    String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean 
sparseRhs, boolean scalarVector,
-                              boolean scalarInput);
-
-    String getTemplate(CNodeTernary.TernaryType type, boolean sparse);
-
-    String getTemplate(SpoofCellwise.CellType ct);
-}
+import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
+import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
+import org.apache.commons.math3.util.FastMath;
+
+/* This is a SPOOF code generation template */
+public final class %TMP% extends SpoofCellwise {
+       public %TMP%() { super(CellType.%TYPE%, %SPARSE_SAFE%, %SEQ%, 
%AGG_OP_NAME%); }
+
+       protected double genexec(double a, SideInput[] b, double[] scalars, int 
m, int n, long grix, int rix, int cix) {
+%BODY_dense%   return %OUT%;
+       }
+};
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
index a499f49..a86d51c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Ternary.java
@@ -19,13 +19,10 @@
 
 package org.apache.sysds.hops.codegen.cplan.java;
 
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
-import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
-public class Ternary implements CodeTemplate {
+public class Ternary extends CodeTemplate {
 
        @Override
        public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
@@ -64,26 +61,4 @@ public class Ternary implements CodeTemplate {
                                throw new RuntimeException("Invalid ternary 
type: "+this.toString());
                }
        }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector,
-                                                         boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeUnary.UnaryType type, boolean sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
 }
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
index 5f6a392..43f1b5d 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
@@ -20,13 +20,10 @@
 package org.apache.sysds.hops.codegen.cplan.java;
 
 import org.apache.commons.lang.StringUtils;
-import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
-import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary.UnaryType;
 import org.apache.sysds.hops.codegen.cplan.CodeTemplate;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
 
-public class Unary implements CodeTemplate {
+public class Unary extends CodeTemplate {
        @Override
        public String getTemplate(UnaryType type, boolean sparse) {
                switch( type ) {
@@ -129,24 +126,4 @@ public class Unary implements CodeTemplate {
                                throw new RuntimeException("Invalid unary type: 
"+this.toString());
                }
        }
-
-       @Override
-       public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, 
boolean sparseRhs, boolean scalarVector, boolean scalarInput) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate() {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(SpoofCellwise.CellType ct) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
-
-       @Override
-       public String getTemplate(CNodeTernary.TernaryType type, boolean 
sparse) {
-               throw new RuntimeException("Calling wrong getTemplate method on 
" + getClass().getCanonicalName());
-       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java 
b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
index e0fc9c4..f3ccf12 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofCUDA.java
@@ -30,19 +30,22 @@ import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-import static 
org.apache.sysds.runtime.matrix.data.LibMatrixNative.isSinglePrecision;
-
 public class SpoofCUDA extends SpoofOperator {
        private static final long serialVersionUID = -2161276866245388359L;
        
        private final CNodeTpl cnt;
+       private final Class<?> java_op;
        public final String name;
+       public final String src;
 
-       public SpoofCUDA(CNodeTpl cnode) {
+       public SpoofCUDA(String source, CNodeTpl cnode, Class<?> java_op) {
                name = "codegen." + cnode.getVarname();
                cnt = cnode;
+               src = source;
+               this.java_op = java_op;
        }
 
        public String getName() {
@@ -71,40 +74,65 @@ public class SpoofCUDA extends SpoofOperator {
        }
 
        public double execute(ArrayList<MatrixObject> inputs, 
ArrayList<ScalarObject> scalarObjects, MatrixObject out_obj,
-                                                          ExecutionContext ec) 
{
+                                                          ExecutionContext ec, 
boolean sparseOut) {
                double ret;
-               long out_ptr = 0;
-
-               if(out_obj != null)
-                       out_ptr = ec.getGPUPointerAddress(out_obj);
-
+               long[] out_ptr = {0,0,0,0};
+               
+               if(out_obj != null) {
+                       if(sparseOut) {
+                               out_ptr[0] = 
ec.getGPUSparsePointerAddress(out_obj).nnz;
+                               out_ptr[1] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).rowPtr);
+                               out_ptr[2] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).colInd);
+                               out_ptr[3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(out_obj).val);
+                       }
+                       else
+                               out_ptr[3] = 
ec.getGPUDensePointerAddress(out_obj);
+               }
+               
                int offset = 1;
                if(cnt instanceof CNodeOuterProduct)
                        offset = 2;
-
+               
                // only dense input preparation for now
-               long[] in_ptrs = new long[offset];
-               for(int i = 0; i < offset; ++i)
-                       in_ptrs[i] = ec.getGPUPointerAddress(inputs.get(i));
-
-               long[] side_ptrs = new long[inputs.size() - offset];
-               for(int i = offset; i < inputs.size(); ++i)
-                       side_ptrs[i - offset] = 
ec.getGPUPointerAddress(inputs.get(i));
-
-               if(isSinglePrecision()) {
-                       float[] scalars = prepInputScalarsFloat(scalarObjects);
-
-                       // ToDo: handle float
-                  ret = 
execute_f(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), 
name.split("\\.")[1],
-                                       in_ptrs, side_ptrs, out_ptr, scalars, 
inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0);
-
+               long[] in_ptrs = new long[offset * 4];
+               for(int i = 0; i < offset; i += 4) {
+                       
if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) {
+                               in_ptrs[i * 4] = 
ec.getGPUSparsePointerAddress(inputs.get(i)).nnz;
+                               in_ptrs[i * 4 + 1] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr);
+                               in_ptrs[i * 4 + 2] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd);
+                               in_ptrs[i * 4 + 3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val);
+                       }
+                       else
+                               in_ptrs[i * 4 + 3] = 
ec.getGPUDensePointerAddress(inputs.get(i));
+               }
+               
+               long[] side_ptrs = new long[(inputs.size() - offset) * 4];
+               for(int i = offset; i < inputs.size(); i++) {
+                       int j = (i - offset)  * 4;
+                       
if(inputs.get(i).getGPUObject(ec.getGPUContext(0)).isSparse()) {
+                               side_ptrs[j] = 
ec.getGPUSparsePointerAddress(inputs.get(i)).nnz;
+                               side_ptrs[j + 1] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).rowPtr);
+                               side_ptrs[j + 2] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).colInd);
+                               side_ptrs[j + 3] = 
GPUObject.getPointerAddress(ec.getGPUSparsePointerAddress(inputs.get(i)).val);
+                       }
+                       else
+                               side_ptrs[j + 3] = 
ec.getGPUDensePointerAddress(inputs.get(i));
                }
-               else {
-                       double[] scalars = prepInputScalars(scalarObjects);
 
+//             if(isSinglePrecision()) {
+//                     float[] scalars = prepInputScalarsFloat(scalarObjects);
+//
+//                     // ToDo: handle float
+//                ret = 
execute_f(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), 
name.split("\\.")[1],
+//                                     in_ptrs, side_ptrs, out_ptr, scalars, 
inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 
out_obj.getNumColumns(),0);
+//
+//             }
+//             else {
+                       double[] scalars = prepInputScalars(scalarObjects);
+               
                        ret = 
execute_d(SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA), 
name.split("\\.")[1],
-                                       in_ptrs, side_ptrs, out_ptr, scalars, 
inputs.get(0).getNumRows(), inputs.get(0).getNumColumns(), 0);
-               }
+                                       in_ptrs, offset, side_ptrs, out_ptr, 
scalars,0, inputs, out_obj);
+//             }
                return ret;
        }
 
@@ -115,8 +143,8 @@ public class SpoofCUDA extends SpoofOperator {
        }
 
        private native float execute_f(long ctx, String name, long[] in_ptr, 
long[] side_ptr,
-                                                                  long 
out_ptr, float[] scalars, long m, long n, long grix);
+                                                                  long 
out_ptr, float[] scalars, long m, long n, long out_len, long grix);
 
-       private native double execute_d(long ctx, String name, long[] in_ptr, 
long[] side_ptr,
-                                                                       long 
out_ptr, double[] scalars, long m, long n, long grix);
+       private native double execute_d(long ctx, String name, long[] in_ptr, 
int offset, long[] side_ptr, long[] out_ptr,
+                       double[] scalars, long grix, ArrayList<MatrixObject> 
inputs, MatrixObject output);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 9587b95..deaa680 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
+import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.lineage.Lineage;
@@ -366,7 +367,7 @@ public class ExecutionContext {
        public Pair<MatrixObject, Boolean> 
getSparseMatrixOutputForGPUInstruction(String varName, long numRows, long 
numCols, long nnz) {
                MatrixObject mo = allocateGPUMatrixObject(varName, numRows, 
numCols);
                mo.getDataCharacteristics().setNonZeros(nnz);
-                               boolean allocated = 
mo.getGPUObject(getGPUContext(0)).acquireDeviceModifySparse();
+               boolean allocated = 
mo.getGPUObject(getGPUContext(0)).acquireDeviceModifySparse();
                return new Pair<>(mo, allocated);
        }
 
@@ -405,15 +406,21 @@ public class ExecutionContext {
                mo.getGPUObject(getGPUContext(0)).addWriteLock();
                return mo;
        }
-
-       public long getGPUPointerAddress(MatrixObject obj) {
-
-                       if(obj.getGPUObject(getGPUContext(0)) == null)
+       
+       public long getGPUDensePointerAddress(MatrixObject obj) {
+               if(obj.getGPUObject(getGPUContext(0)) == null)
                                return 0;
                        else
-                               return 
obj.getGPUObject(getGPUContext(0)).getPointerAddress();
+                               return 
obj.getGPUObject(getGPUContext(0)).getDensePointerAddress();
        }
-
+       
+       public CSRPointer getGPUSparsePointerAddress(MatrixObject obj) {
+               if(obj.getGPUObject(getGPUContext(0)) == null)
+                       throw new RuntimeException("No CSRPointer for 
MatrixObject " + obj.toString());
+               else
+                       return 
obj.getGPUObject(getGPUContext(0)).getJcudaSparseMatrixPtr();
+       }
+       
        public MatrixObject getMatrixInputForGPUInstruction(String varName, 
String opcode) {
                GPUContext gCtx = getGPUContext(0);
                MatrixObject mo = getMatrixObject(varName);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
index 20f4333..6f3a1f7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/SpoofCUDAInstruction.java
@@ -31,6 +31,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.lineage.LineageTraceable;
@@ -88,21 +89,46 @@ public class SpoofCUDAInstruction extends GPUInstruction 
implements LineageTrace
 
                // set the output dimensions to the hop node matrix dimensions
                if( _out.getDataType() == Types.DataType.MATRIX) {
-                       long rows = inputs.get(0).getNumRows();
-                       long cols = inputs.get(0).getNumColumns();
+                       long out_rows = 
ec.getMatrixObject(_out.getName()).getNumRows(); 
+                       long out_cols = 
ec.getMatrixObject(_out.getName()).getNumColumns();
+                       
                        if(_op.getSpoofTemplateType().contains("CW"))
                                
if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == 
SpoofCellwise.CellType.COL_AGG)
-                                       rows = 1;
+                                       out_rows = 1;
                                else 
if(((CNodeCell)_op.getCNodeTemplate()).getCellType() == 
SpoofCellwise.CellType.ROW_AGG)
-                                       cols = 1;
-
-                       MatrixObject out_obj = 
ec.getDenseMatrixOutputForGPUInstruction(_out.getName(), rows, cols).getKey();
-                       ec.setMetaData(_out.getName(), out_obj.getNumRows(), 
out_obj.getNumColumns());
-                       _op.execute(inputs, scalars, out_obj, ec);
+                                       out_cols = 1;
+                       
+                       
+                       if(_op.getSpoofTemplateType().contains("RA")) {
+                               // ToDo: make this workaround proper!!
+                               boolean isConstDim2 = false;
+                               int pos = _op.src.indexOf("// ConstDim2: ");
+                               String strDim2 = _op.src.substring(pos + 14, 
_op.src.indexOf(System.lineSeparator(), pos));
+                               int dim2 = Integer.parseInt(strDim2);
+                               if(dim2 > 0)
+                                       isConstDim2 = true;
+                               
+                               long n = inputs.get(0).getNumColumns();
+                               long n2 = isConstDim2 ? dim2 : 
inputs.get(0).getNumRows();
+                               if(_op.src.contains("COL_AGG_B1_T")) {
+                                       out_rows = n;
+                                       out_cols = n2;
+                               }
+                               
+                       }
+                       ec.setMetaData(_out.getName(), out_rows, out_cols);
+                       GPUObject g = 
inputs.get(0).getGPUObject(ec.getGPUContext(0));
+                       boolean sparseOut = g.isSparse() && 
_op.getSpoofTemplateType().contains("CW");
+                       MatrixObject out_obj = sparseOut ?
+                               
(ec.getSparseMatrixOutputForGPUInstruction(_out.getName(), out_rows, out_cols, 
g.getNnz(getOpcode(), false)).getKey()) :
+                               
(ec.getDenseMatrixOutputForGPUInstruction(_out.getName(), out_rows, 
out_cols).getKey());
+                       
+//                     ec.setMetaData(_out.getName(), out_obj.getNumRows(), 
out_obj.getNumColumns());
+                       _op.execute(inputs, scalars, out_obj, ec, sparseOut);
                        ec.releaseMatrixOutputForGPUInstruction(_out.getName());
                }
                else if (_out.getDataType() == Types.DataType.SCALAR) {
-                       ScalarObject out = new DoubleObject(_op.execute(inputs, 
scalars, null, ec));
+                       ScalarObject out = new DoubleObject(_op.execute(inputs, 
scalars, null, ec, false));
                        ec.setScalarOutput(_out.getName(), out);
                }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
index b2967bd..f59ab95 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUObject.java
@@ -83,7 +83,7 @@ public class GPUObject {
        protected boolean writeLock = false;
 
        /**
-        * Timestamp, needed by {@link GPUContext#evict(long)}
+        * Timestamp, needed by {@link GPUContext\#evict(long)}
         */
        AtomicLong timestamp = new AtomicLong();
 
@@ -1006,7 +1006,7 @@ public class GPUObject {
                return sb.toString();
        }
 
-       private static long getPointerAddress(Pointer p) {
+       private static long getPointerAddressInternal(Pointer p) {
                // WORKAROUND until a method like CUdeviceptr#getAddress exists 
in jCuda
                class PointerWithAddress extends Pointer
                {
@@ -1021,8 +1021,11 @@ public class GPUObject {
                }
                return new PointerWithAddress(p).getAddress();
        }
-
-       public long getPointerAddress() {
-               return getPointerAddress(getDensePointer());
+       
+       public long getDensePointerAddress() {
+               return getPointerAddressInternal(getDensePointer());
        }
-}
+       
+       public static long getPointerAddress(Pointer p) {
+               return getPointerAddressInternal(p);
+       }}

Reply via email to