This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b651db516d [SYSTEMDS-3669] Builtin for computation of shapley values
b651db516d is described below

commit b651db516da31222018396bb996b3d825766c7da
Author: louislepage <[email protected]>
AuthorDate: Sun Oct 27 17:22:35 2024 +0100

    [SYSTEMDS-3669] Builtin for computation of shapley values
    
    Closes #1946.
---
 scripts/builtin/shapExplainer.dml                  | 732 +++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../builtin/part2/BuiltinShapExplainerTest.java    | 156 +++++
 .../functions/builtin/shapExplainerComponent.dml   |  57 ++
 .../functions/builtin/shapExplainerUnit.dml        | 104 +++
 5 files changed, 1050 insertions(+)

diff --git a/scripts/builtin/shapExplainer.dml 
b/scripts/builtin/shapExplainer.dml
new file mode 100644
index 0000000000..626dc7da4c
--- /dev/null
+++ b/scripts/builtin/shapExplainer.dml
@@ -0,0 +1,732 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes shapley values for multiple instances in parallel using antithetic 
permutation sampling.
+# The resulting matrix phis holds the shapley values for each feature in the 
column given by the index of the feature in the sample.
+#
+# This method first creates two large matrices for masks and masked background 
data for all permutations and
+# then runs in parallel on all instances in x.
+# While the prepared matrices can become very large (2 * #features * 
#permutations * #n_samples * #features),
+# the preparation of a row for the model call breaks down to a single 
element-wise multiplication of this mask with the row and
+# an addition to the masked background data, since masks can be reused for 
each instance.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# model_function  The function of the model to be evaluated as a String. This 
function has to take a matrix of samples
+#                 and return a vector of predictions.
+#                 It might be useful to wrap the model into a function that 
takes and returns the desired shapes and
+#                 use this wrapper here.
+# model_args      Arguments in order for the model, if desired. This will be 
prepended by the created instances-matrix.
+# x_instances     Multiple instances as rows for which to compute the shapley 
values.
+# X_bg            The background dataset from which to pull the random samples 
to perform Monte Carlo integration.
+# n_permutations  The number of permutations. Defaults to 10. Theoretically 1 
should already be enough for models with up
+#                 to second order interaction effects.
+# n_samples       Number of samples from X_bg used for marginalization.
+# remove_non_var  EXPERIMENTAL: If set, for every instance the variance of 
each feature is checked against this feature in the
+#                 background data. If it does not change, we do not run any 
model calls for it.
+# seed            A seed, in case the sampling has to be deterministic.
+# verbose         A boolean to enable logging of each step of the function.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# S              Matrix holding the shapley values along the cols, one row per 
instance.
+# expected       Double holding the average prediction of all instances.
+# -----------------------------------------------------------------------------
+s_shapExplainer = function(String model_function, list[unknown] model_args, 
Matrix[Double] x_instances,
+    Matrix[Double] X_bg, Integer n_permutations = 10, Integer n_samples = 100, 
Integer remove_non_var=0,
+    Matrix[Double] partitions=as.matrix(-1), Integer seed = -1, Integer 
verbose = 0)
+  return (Matrix[Double] row_phis, Double expected)
+{
+  u_printShapMessage("Parallel Permutation Explainer for "+nrow(x_instances)+" 
rows.", verbose)
+  u_printShapMessage("Number of Features: "+ncol(x_instances), verbose )
+  total_preds=ncol(x_instances)*2*n_permutations*n_samples*nrow(x_instances)
+  u_printShapMessage("Number of predictions: "+toString(total_preds)+" in 
"+nrow(x_instances)+
+    " parallel cals.", verbose )
+
+  #start with all features
+  features=u_range(1, ncol(x_instances))
+
+  #handle partitions
+  if(sum(partitions) != -1){
+    if(remove_non_var != 0){
+      stop("shapley_permutations_by_row:ERROR: Can't use n_non_varying_inds 
and partitions at the same time.")
+    }
+    features=removePartitionsFromFeatures(features, partitions)
+    
reduced_total_preds=ncol(features)*2*n_permutations*n_samples*nrow(x_instances)
+    u_printShapMessage("Using Partitions reduces number of features to 
"+ncol(features)+".", verbose )
+    u_printShapMessage("Total number of predictions reduced by 
"+(total_preds-reduced_total_preds)/total_preds+" to "+reduced_total_preds+".", 
verbose )
+  }
+
+  #lengths and offsets
+  total_features = ncol(x_instances)
+  perm_length = ncol(features)
+  full_mask_offset = perm_length * 2 * n_samples
+  n_partition_features = total_features - perm_length
+
+  #sample from X_bg
+  u_printShapMessage("Sampling from X_bg", verbose )
+  # could use new samples for each permutation by sampling 
n_samples*n_permutations
+  X_bg_samples = u_sample_with_potential_replace(X_bg=X_bg, samples=n_samples, 
seed=seed )
+  row_phis     = matrix(0, rows=nrow(x_instances), cols=total_features)
+  expected_m   = matrix(0, rows=nrow(x_instances), cols=1)
+
+  #prepare masks for all permutations, since it stays the same for every row
+  u_printShapMessage("Preparing reusable intermediate masks.", verbose )
+  permutations = matrix(0, rows=n_permutations, cols=perm_length)
+  masks_for_permutations = matrix(0, 
rows=perm_length*2*n_permutations*n_samples, cols=total_features)
+
+  parfor (i in 1:n_permutations, check=0){
+    #shuffle features to get permutation
+    permutations[i] = t(u_shuffle(t(features)))
+    perm_mask = prepare_mask_for_permutation(permutation=permutations[i], 
partitions=partitions)
+
+    offset_masks = (i-1) * full_mask_offset + 1
+    
masks_for_permutations[offset_masks:offset_masks+full_mask_offset-1]=prepare_full_mask(perm_mask,
 n_samples)
+  }
+
+  #replicate background and mask it, since it also can stay the same for every 
row
+  # could use new samples for each permutation by sampling 
n_samples*n_permutations and telling this function about it
+  masked_bg_for_permutations = prepare_masked_X_bg(masks_for_permutations, 
X_bg_samples, 0)
+  u_printShapMessage("Computing phis in parallel.", verbose )
+
+  #enable spark execution for parfor if desired
+  #TODO allow spark mode via parameter?
+  #parfor (i in 1:nrow(x_instances), opt=CONSTRAINED, mode=REMOTE_SPARK){
+
+  parfor (i in 1:nrow(x_instances)){
+    if(remove_non_var == 1){
+      # try to remove inds that do not vary from the background
+      non_var_inds = get_non_varying_inds(x_instances[i], X_bg_samples)
+      # only remove if more than 2 features remain, less than two breaks 
removal procedure
+      if (ncol(x_instances) > length(non_var_inds)+2){
+        #remove samples and masks for non varying features
+        [i_masks_for_permutations, i_masked_bg_for_permutations] = 
remove_inds(masks_for_permutations, masked_bg_for_permutations, permutations, 
non_var_inds, n_samples)
+      }else{
+        # we would remove all but two features, which breaks the removal 
algorithm
+        non_var_inds = as.matrix(-1)
+        i_masks_for_permutations = masks_for_permutations
+        i_masked_bg_for_permutations = masked_bg_for_permutations
+      }
+    } else {
+      non_var_inds = as.matrix(-1)
+      i_masks_for_permutations = masks_for_permutations
+      i_masked_bg_for_permutations = masked_bg_for_permutations
+    }
+
+    #apply masks and bg data for all permutations at once
+    X_test = apply_full_mask(x_instances[i], i_masks_for_permutations, 
i_masked_bg_for_permutations)
+
+    #generate args for call to model
+    X_arg = append(list(X=X_test), model_args)
+
+    #call model
+    P = eval(model_function, X_arg)
+
+    #compute means, dividing n_rows by n_samples
+    P = compute_means_from_predictions(P=P, n_samples=n_samples)
+
+    #compute phis
+    [phis, e] = compute_phis_from_prediction_means(P=P, 
permutations=permutations, non_var_inds=non_var_inds, 
n_partition_features=n_partition_features)
+    expected_m[i] = e
+
+    #compute phis for this row from all permutations
+    row_phis[i] = t(phis)
+  }
+  #compute expected of model from all rows
+  expected = mean(expected_m)
+}
+
+# Computes which indices do not vary from the background.
+# Uses the approach from numpy.isclose() and compares to the largest diff of 
each feature in the bg data.
+# In the future, more advanced techniques like using std-dev of bg data as a 
tolerance could be used.
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# x     One single instance.
+# X_bg  Background dataset.
+# -----------------------------------------------------------------------------
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# non_varying_inds A row-vector with all the indices that do not vary from the 
background dataset.
+# -----------------------------------------------------------------------------
+get_non_varying_inds = function(Matrix[Double] x, Matrix[Double] X_bg)
+return (Matrix[Double] non_varying_inds){
+  #from numpy.isclose but adapted to fit MSE of shap, which is within the same 
scale
+  rtol = 1e-04
+  atol = 1e-05
+
+  # compute distance metrics
+  diff = colMaxs(abs(X_bg -x))
+  rdist = atol + rtol * colMaxs(abs(X_bg))
+
+  non_varying_inds = (diff <= rdist)
+  # translate to indices
+  non_varying_inds = t(seq(1,ncol(x))) * non_varying_inds
+  # remove the ones that do vary
+  non_varying_inds = removeEmpty(target=non_varying_inds, margin="cols")
+}
+
+# Prepares a boolean mask for removing features according to permutation.
+# The resulting matrix needs to be inflated to a sample set by using 
prepare_samples_from_mask() before calling the model.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# permutation         A single permutation of varying features.
+#                     If using partitions, remove them beforehand by using 
removePartitionsFromFeatures() from the utils.
+# n_non_varying_inds  The number of features that do not vary in the background 
data.
+#                     Can be retrieved e.g. by looking at std.dev
+# partitions          Matrix with first element of partition in first row and 
last element of partition in second row.
+#                     Used to treat partitions as one feature when creating 
masks. Useful for one-hot-encoded features.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# mask           Boolean mask.
+# -----------------------------------------------------------------------------
+prepare_mask_for_permutation = function(Matrix[Double] permutation, Integer 
n_non_varying_inds=0,
+       Matrix[Double] partitions=as.matrix(-1))
+return (Matrix[Double] masks){
+  if(sum(partitions)!=-1){
+    #can't use n_non_varying_inds and partitions at the same time
+    if(n_non_varying_inds > 0){
+      stop("shap-explainer::prepare_mask_for_permutation:ERROR: Can't use 
n_non_varying_inds and partitions at the same time.")
+    }
+    #number of features not in permutation is diff between start and end of 
partitions, since first feature remains in permutation
+    skip_inds = partitions[2,] - partitions[1,]
+
+    #skip these inds by treating them as non varying
+    n_non_varying_inds = sum(skip_inds)
+  }
+
+  #total number of features
+  perm_len = ncol(permutation)+n_non_varying_inds
+  if(n_non_varying_inds > 0){
+    #prep full constructor with placeholders
+    mask_constructor = matrix(perm_len+1, rows=1, cols = perm_len)
+    mask_constructor[1,1:ncol(permutation)] = permutation
+  }else{
+    mask_constructor=permutation
+  }
+
+  perm_cols = ncol(mask_constructor)
+
+  # we compute mask on reverse permutation and reverse it later to get desired 
shape
+
+  # create row indicator vector ctable
+  perm_mask_rows = seq(1,perm_cols)
+  #TODO: col-vector and matrix mult?
+  perm_mask_rows = matrix(1, rows=perm_cols, cols=perm_cols) * perm_mask_rows
+  perm_mask_rows = lower.tri(target=perm_mask_rows, diag=TRUE, values=TRUE)
+  perm_mask_rows = removeEmpty(target=matrix(perm_mask_rows, rows=1, 
cols=length(perm_mask_rows)), margin="cols")
+
+  # create column indicator for ctable
+  rev_permutation = t(rev(t(mask_constructor)))
+  #TODO: col-vector and matrix mult?
+  perm_mask_cols = matrix(1, rows=perm_cols, cols=perm_cols) * mask_constructor
+  perm_mask_cols = lower.tri(target=perm_mask_cols, diag=TRUE, values=TRUE)
+  perm_mask_cols = removeEmpty(target = matrix(perm_mask_cols, 
cols=length(perm_mask_cols), rows=1), margin="cols")
+  #ctable
+  masks = table(perm_mask_rows, perm_mask_cols, perm_len, perm_len)
+  if(n_non_varying_inds > 0){
+    #truncate non varying rows
+    masks = masks[1:ncol(permutation)]
+
+    #replicate mask from first feature of each partition to entire partitions
+    if(sum(partitions)!=-1){
+      for ( i in 1:ncol(partitions) ){
+        p_start = as.scalar(partitions[1,i])
+        p_end   = as.scalar(partitions[2,i])
+        proxy = masks[,p_start] %*% matrix(1, rows=1, cols=p_end-p_start)
+        masks[,p_start+1:p_end] = proxy
+      }
+    }
+  }
+
+  # add inverted mask and revert order for desired shape for forward and 
backward pass
+  masks = rbind(!masks[nrow(masks)],masks, rev(!masks[1:nrow(masks)-1]))
+}
+
+# Prepares the full mask for marginalization by repeating the rows
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# mask                    Boolean mask with 1, where from x, and 0, where 
integrated over background data.
+# n_samples               Number samples for which to replicate.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# x_mask_full             A replicated mask.
+# -----------------------------------------------------------------------------
+prepare_full_mask = function(Matrix[Double] mask, Integer n_samples)
+  return (Matrix[Double] x_mask_full){
+  x_mask_full = u_repeatRows(mask,n_samples)
+}
+
+# Prepares the masked background by replicating the samples and masking them 
using the full mask.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# x_mask_full             Boolean mask replicated row-wise.
+# X_bg_samples            Samples from background. Either the same n samples 
for all permutations or
+#                         n*p samples, so each permutation has its own samples.
+# n_perms_in_samples      Number of sample sets to identify block which need 
to be replicated in X_bg_samples.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# x_mask_full             A replicated mask.
+# -----------------------------------------------------------------------------
+prepare_masked_X_bg = function(Matrix[Double] x_mask_full, Matrix[Double] 
X_bg_samples, Integer n_perms_in_samples)
+return (Matrix[Double] masked_X_bg){
+  #Repeat background once for every row in original mask.
+  #If the same samples are used for each permutation, simply repeat the entire 
samples accordingly
+  if (n_perms_in_samples <= 1){
+    #Since x_mask_full was already replicated row-wise by the number of rows 
in X_bg_samples, we divide by it.
+    masked_X_bg = u_repeatMatrix(X_bg_samples, 
nrow(x_mask_full)/nrow(X_bg_samples))
+  }else{
+    # if X_bg_samples has independent samples for each perm, it holds 
n_samples*n_perms rows.
+    block_size = nrow(X_bg_samples)/n_perms_in_samples
+    masked_X_bg = u_repeatMatrixBlocks(X_bg_samples, block_size,  
nrow(x_mask_full)/block_size/n_perms_in_samples)
+  }
+
+  masked_X_bg = masked_X_bg * !x_mask_full
+}
+
+# Applies the masked background and boolean mask to individual instance of 
interest.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# x_row           Instance of interest as row-vector.
+# x_mask_full     Boolean mask replicated row-wise.
+# masked_X_bg     Prepared background samples.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_masked        Set of synthesized instances for x_row.
+# -----------------------------------------------------------------------------
+apply_full_mask = function(Matrix[Double] x_row, Matrix[Double] x_mask_full, 
Matrix[Double] masked_X_bg)
+return (Matrix[Double] X_masked){
+  #add the masked data from this row
+  X_masked = masked_X_bg + (x_mask_full * x_row)
+}
+
+# Removes all rows from the prepared masks and background data whenever their 
feature is marked as non-varying.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# masks               Prepared and replicated mask for a single instance.
+# masked_X_bg         Prepared and replicated background data.
+# full_permutations   The permutations from which the masks and bd data were 
created.
+# non_var_inds        A row-vector containing the indices that were found to 
be not varying for this instance.
+# n_samples           The number of samples over which each row is integrated.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# sub_mask            A subset of masks where for each permutation the rows 
that correspond to
+#                     non-varying features are removed.
+# sub_masked_X_bg     A subset of the background data where for each 
permutation the rows that correspond to
+#                     non-varying features are removed.
+# -----------------------------------------------------------------------------
+remove_inds = function(Matrix[Double] masks, Matrix[Double] masked_X_bg, 
Matrix[Double] full_permutations,
+  Matrix[Double] non_var_inds, Integer n_samples)
+return(Matrix[Double] sub_mask, Matrix[Double] sub_masked_X_bg){
+  offsets = seq(0,length(full_permutations)-ncol(full_permutations), 
ncol(full_permutations))
+
+  ###
+  # get row indices from permutations
+  total_row_index = full_permutations + offsets
+  total_row_index = matrix(total_row_index, rows=length(total_row_index), 
cols=1)
+
+  row_index = toOneHot(total_row_index, nrow(total_row_index))
+  ####
+  # get indices for all permutations as boolean mask
+  # repeat inds for every permutation
+  non_var_inds = matrix(1, rows=nrow(full_permutations), 
cols=ncol(non_var_inds)) * non_var_inds
+  #add offset
+  non_var_total = non_var_inds + offsets
+  #reshape into col-vec
+  non_var_total = matrix(non_var_total,rows=length(non_var_total), cols=1, 
byrow=FALSE)
+  non_var_mask = toOneHot(non_var_total, nrow(total_row_index))
+
+  non_var_mask = colSums(non_var_mask)
+
+  ###
+  # multiply to get mask
+  non_var_rows = row_index %*% t(non_var_mask)
+
+  ####
+  # unfold to full mask length
+  # reshape to add for each permutations
+  reshaped_rows = matrix(non_var_rows, rows=ncol(full_permutations), 
cols=nrow(full_permutations), byrow=FALSE)
+
+  reshaped_rows_full = matrix(0,rows=1,cols=ncol(reshaped_rows))
+
+  #rbind to manipulate all perms at once
+  if( sum(reshaped_rows[nrow(reshaped_rows)]) > 0 ){
+    #fix last row issue by setting last zero to one, if 1 in last row
+    row_indicator = (!reshaped_rows) * seq(1, nrow(reshaped_rows), 1)
+    row_indicator = colMaxs(row_indicator)
+    row_indicator = t(toOneHot(t(row_indicator), nrow(reshaped_rows)))
+    reshaped_rows_2 = reshaped_rows[1:nrow(reshaped_rows)-1] + 
row_indicator[1:nrow(reshaped_rows)-1]
+    reshaped_rows_full = 
rbind(reshaped_rows_full,reshaped_rows,reshaped_rows_2)
+  }else{
+    reshaped_rows_full = 
rbind(reshaped_rows_full,reshaped_rows,reshaped_rows[1:nrow(reshaped_rows)-1])
+  }
+  #reshape into col-vec
+  non_var_total = matrix(reshaped_rows_full, rows=length(reshaped_rows_full), 
cols=1, byrow=FALSE)
+
+  #replicate, if masks already replicated
+  if (n_samples > 1){
+    non_var_total = matrix(1, rows=nrow(non_var_total), cols=n_samples) * 
non_var_total
+    non_var_total = matrix(non_var_total, rows=length(non_var_total), cols=1)
+  }
+
+  #remove from mask according to this vector
+  sub_mask = removeEmpty(target=masks, select=!non_var_total, margin="rows")
+  #set to 1 where non varying
+  #sub_mask = removed_short_mask | non_var_mask[1, 1:ncol(removed_short_mask)]
+  sub_masked_X_bg = removeEmpty(target=masked_X_bg, select=!non_var_total, 
margin="rows")
+}
+
+# Performs the integration/marginalization by computing means.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# P                     Predictions from model.
+# n_samples   Number of samples over which to take the mean.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# P_means               The means of the sample groups. Each row is one group 
with means in cols.
+# -----------------------------------------------------------------------------
+compute_means_from_predictions = function(Matrix[Double] P, Integer n_samples)
+  return (Matrix[Double] P_means){
+  n_features = nrow(P)/n_samples
+
+  #transpose and reshape to concat all values of same type
+  # TODO: unnecessary for vectors, only t() would be needed
+  P = matrix(t(P), cols=1, rows=length(P))
+
+  #reshape, so all predictions from one batch are in one row
+  P = matrix(P, cols=n_samples, rows=length(P)/n_samples)
+
+  #compute row means
+  P_means = rowMeans(P)
+
+  # reshape and transpose to get back to input dimensions
+  P_means = matrix(P_means, rows=n_features, cols=length(P_means)/n_features)
+}
+
+# Computes phis from predictions for a permutation.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# P                     Predictions for multiple permutations.
+# permutations          Permutations to get the feature indices from.
+# non_var_inds          Matrix holding the indices of non-varying features in 
the permutation that were ignored
+#                       during prediction. These will be removed from the 
<permutations> during computation of the phis.
+# n_partition_features  Number of features that are in partitions - number of 
partitions:
+#                       There is still one feature per partition kept in the 
perms!
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# phis                  Phis or shapley values computed from this permutation.
+#                       Every row holds the phis for the corresponding feature.
+# -----------------------------------------------------------------------------
+compute_phis_from_prediction_means = function(Matrix[Double] P, Matrix[Double] 
permutations,
+  Matrix[Double] non_var_inds=as.matrix(-1), Integer n_partition_features = 0)
+return(Matrix[Double] phis, Double expected){
+  perm_len=ncol(permutations)
+  n_non_var_inds = 0
+  partial_permutations = permutations
+
+  if(sum(non_var_inds)>0){
+    n_non_var_inds = ncol(non_var_inds)
+    #flatten perms to remove from all perms at once
+    perms_flattened = matrix(permutations, rows=length(permutations), cols=1)
+    rem_selector = outer(perms_flattened, non_var_inds, "==")
+    rem_selector = rowSums(rem_selector)
+    partial_permutations = removeEmpty(target=perms_flattened, 
select=!rem_selector, margin="rows")
+    #reshape
+    partial_permutations = matrix(partial_permutations, 
rows=perm_len-n_non_var_inds, cols=nrow(permutations))
+    perm_len = perm_len-n_non_var_inds
+  }
+
+  #reshape P to get one col per permutation
+  P_perm = matrix(P, rows=2*perm_len, cols=nrow(permutations), byrow=FALSE)
+
+  #forwards phis
+  forward_phis = P_perm[2:perm_len+1] - P_perm[1:perm_len]
+
+  #backward phis and fix first and last
+  backward_phis = rbind(P_perm[perm_len+2] - P_perm[1], 
P_perm[perm_len+3:2*perm_len] - P_perm[perm_len+2:2*perm_len-1], 
P_perm[perm_len+1] - P_perm[2*perm_len])
+  #reverse to match order of features in permutation
+  backward_phis = rev(backward_phis)
+  #avg forward and backward
+  forward_phis = matrix(forward_phis, rows=length(forward_phis), cols=1, 
byrow=FALSE)
+  backward_phis = matrix(backward_phis, rows=length(backward_phis), cols=1, 
byrow=FALSE)
+  avg_phis = (forward_phis + backward_phis) / 2
+
+  #aggregate to get only one phi per feature (and implicitly add zeros for non 
var inds)
+  perms_flattened = matrix(partial_permutations, 
rows=length(partial_permutations), cols=1)
+  phis = aggregate(target=avg_phis, groups=perms_flattened, fn="mean", 
ngroups=ncol(permutations)+n_partition_features)
+
+  #get expected from first row
+  expected=mean(P_perm[1])
+}
+
+# Removes features that are part of a partition.
+# Keeps first feature of partition as proxy for partition.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# features        Matrix holding features in its cols.
+# partitions      Matrix holding start and end of partitions in the cols of 
the first and second row respectively.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# short_features  Matrix like features, but with the ones from partitions 
removed.
+# -----------------------------------------------------------------------------
+removePartitionsFromFeatures = function(Matrix[Double] features, 
Matrix[Double] partitions)
+return (Matrix[Double] short_features){
+  #remove from features
+  rm_mask = matrix(0, rows=1, cols=ncol(features))
+  for (i in 1:ncol(partitions)){
+    part_start = as.scalar(partitions[1,i])
+    part_end   = as.scalar(partitions[2,i])
+    #include part_start as proxy of partition
+    rm_mask = rm_mask + (features > part_start) * (features <= part_end)
+  }
+  short_features = removeEmpty(target=features, margin="cols", select=!rm_mask)
+}
+
+########################
+# Utility Functions that might be worth refactoring into its own file
+# They could be used in other scenarios as well
+########################
+
+
+# Samples from the background data X_bg.
+# The function first uses all background samples without replacement, but if 
more samples are requested than
+# available in X_bg, it shuffles X_bg and pulls more samples from it, making 
it sampling with replacement.
+# TODO: Might be replaceable by other builtin for sampling in the future
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# X_bg            Matrix of background data
+# samples         Number of total samples
+# always_shuffle  Boolean to enable reshuffling of X_bg, defaults to false.
+# seed            A seed for the shuffling etc.
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_sample        New Matrix containing #samples, from X_bg, potentially with 
replacement.
+# -----------------------------------------------------------------------------
+u_sample_with_potential_replace = function(Matrix[Double] X_bg, Integer 
samples, Boolean always_shuffle = 0, Integer seed)
+return (Matrix[Double] X_sample){
+  number_of_bg_samples = nrow(X_bg)
+
+  # expect to not use all from background and subsample from it
+  num_of_full_X_bg = 0
+  num_of_remainder_samples = samples
+
+  # shuffle background if desired
+  if(always_shuffle) {
+  X_bg = u_shuffle(X_bg)
+  }
+
+  # list to store references to generated matrices so we can rbind them in one 
call
+  samples_list = list()
+
+  # in case we need more than in the background data, use it multiple times 
with replacement
+  if(samples >= number_of_bg_samples)  {
+    u_printShapMessage("WARN: More samples ("+toString(samples)+") are 
requested than available in the background dataset 
("+toString(number_of_bg_samples)+"). Using replacement", 1)
+
+    # get number of full sets of background by integer division
+    num_of_full_X_bg = samples %/% number_of_bg_samples
+    # get remaining samples using modulo
+    num_of_remainder_samples = samples %% number_of_bg_samples
+
+    #use background data once
+    samples_list = append(samples_list, X_bg)
+
+    if(num_of_full_X_bg > 1){
+      # add shuffled versions of background data
+      for (i in 1:num_of_full_X_bg-1){
+      samples_list = append(samples_list, u_shuffle(X_bg))
+      }
+    }
+  }
+
+  # sample from background dataset for remaining samples
+  if (num_of_remainder_samples > 0){
+    # pick remaining samples
+    random_samples_indices = sample(number_of_bg_samples, 
num_of_remainder_samples, seed)
+
+    #contingency table to pick rows by multiplication
+    R_cont = table(random_samples_indices, random_samples_indices, 
number_of_bg_samples, number_of_bg_samples)
+
+    #pick samples by multiplication with contingency table of indices and 
removing empty rows
+    samples_list = append(samples_list, removeEmpty(target=t(t(X_bg) %*% 
R_cont), margin="rows"))
+  }
+
+
+  if ( length(samples_list) == 1){
+    #dont copy if only one matrix is in list, since this is a heavy hitter
+    X_sample = as.matrix(samples_list[1])
+  } else {
+    #single call to bind all generated samples into one large matrix
+    X_sample = rbind(samples_list)
+  }
+}
+
+# Simple utility function to shuffle the rows of a matrix (adapted from
+# shuffle.dml, but without storing to file).
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# X               Matrix to be shuffled
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_shuffled      Matrix with the same rows as X, in random order.
+# -----------------------------------------------------------------------------
+u_shuffle = function(Matrix[Double] X)
+return (Matrix[Double] X_shuffled){
+  n_cols = ncol(X)
+  # attach a column of uniform random sort keys, order rows by that key,
+  # then drop the key column again
+  rand_keys = rand(rows=nrow(X), cols=1, min=0, max=1, pdf="uniform")
+  sorted = order(target = cbind(X, rand_keys), by = n_cols + 1)
+  X_shuffled = sorted[, 1:n_cols]
+}
+
+# Simple utility function to create a range of integers from start to end
+# (both inclusive) as a row vector.
+#
+# INPUT:
+# 
---------------------------------------------------------------------------------------
+# start           First integer of range.
+# end             Last integer of range (inclusive).
+# 
---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# range           Matrix with range from start to end in its cols.
+# -----------------------------------------------------------------------------
+u_range = function(Integer start, Integer end)
+return (Matrix[Double] range){
+  # seq produces the column vector start..end directly; transpose to a row
+  # vector (replaces the previous cumsum-based construction)
+  range = t(seq(start, end, 1))
+}
+
+# Replicates rows of the input matrix n-times.
+#
+# Example:
+# [1,2]
+# [3,4]
+# becomes
+# [1,2]
+# [1,2]
+# [3,4]
+# [3,4]
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix where rows will be replicated.
+# n_times  Number of replications.
+# -----------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# M         Matrix of replicated rows.
+# -----------------------------------------------------------------------------
+u_repeatRows = function(Matrix[Double] M, Integer n_times)
+return(Matrix[Double] M){
+  # get indices for new rows (e.g. 1,1,1,2,2,2 for 2 rows, each replicated 3 times)
+  indices = ceil(seq(1,nrow(M)*n_times,1) / n_times)
+
+  # to one-hot, so we get a replication matrix R of shape (nrow(M)*n_times) x nrow(M)
+  R = toOneHot(indices, nrow(M))
+
+  # matrix-multiply to repeat rows
+  M = R %*% M
+}
+
+# Replicates the whole matrix n-times block-wise, i.e. stacks full copies of
+# the input on top of each other.
+#
+# Example:
+# [1,2]
+# [3,4]
+# becomes
+# [1,2]
+# [3,4]
+# [1,2]
+# [3,4]
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix to be replicated.
+# n_times  Number of replications.
+# -----------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# M         Matrix consisting of n_times stacked copies of the input.
+# -----------------------------------------------------------------------------
+u_repeatMatrix = function(Matrix[Double] M, Integer n_times)
+return(Matrix[Double] M){
+  orig_rows = nrow(M)
+  orig_cols = ncol(M)
+  # flatten the input into a single row vector
+  flat = matrix(M, rows=1, cols=orig_rows*orig_cols)
+  # broadcast the row vector onto n_times rows, then restore the column count
+  M = matrix(matrix(1, rows=n_times, cols=1) * flat, rows=orig_rows*n_times, cols=orig_cols)
+}
+
+# Like repeatMatrix(), but allows to define parts of the matrix as blocks and
+# replicates each block of rows as a unit.
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# M               Matrix whose row-blocks will be replicated.
+# rows_per_block  Number of consecutive rows forming one block
+#                 (assumes it divides nrow(M) evenly — the reshape below
+#                 fails otherwise; TODO confirm callers guarantee this).
+# n_times         Number of replications per block.
+# -----------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# M               Matrix of replicated row-blocks.
+# -----------------------------------------------------------------------------
+u_repeatMatrixBlocks = function(Matrix[Double] M, Integer rows_per_block, Integer n_times)
+return(Matrix[Double] M){
+  n_rows=nrow(M)
+  n_cols=ncol(M)
+  # reshape so each block of rows becomes a single row
+  M = matrix(M, rows=n_rows/rows_per_block, cols=n_cols*rows_per_block)
+  # repeat block rows
+  M = u_repeatRows(M, n_times)
+  # reshape to get matrix
+  M = matrix(M, rows=n_rows*n_times, cols=n_cols)
+}
+
+# Utility function to print a message prefixed with the shap-explainer tag.
+# Output is emitted only when the verbose flag is set.
+u_printShapMessage = function(String message, Boolean verbose){
+  if(verbose){
+    print("shap-explainer::"+message)
+  }
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 98f92ae55e..ca31cd331a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -300,6 +300,7 @@ public enum Builtins {
        SELVARTHRESH("selectByVarThresh", true),
        SEQ("seq", false),
        SYMMETRICDIFFERENCE("symmetricDifference", true),
+       SHAPEXPLAINER("shapExplainer", true),
        SHERLOCK("sherlock", true),
        SHERLOCKPREDICT("sherlockPredict", true),
        SHORTESTPATH("shortestPath", true),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
new file mode 100644
index 0000000000..0c207d66f3
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Tests for the shapExplainer builtin: unit tests for its internal DML
+ * functions (shapExplainerUnit.dml) and a component test of the full
+ * explainer (shapExplainerComponent.dml).
+ */
+public class BuiltinShapExplainerTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME = "shapExplainer";
+	private static final String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinShapExplainerTest.class.getSimpleName() + "/";
+
+	//FIXME need for padding result with zero
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+	}
+
+	@Test
+	public void testPrepareMaskForPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_permutation");
+	}
+
+	@Test
+	public void testPrepareMaskForPartialPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_partial_permutation");
+	}
+
+	@Test
+	public void testPrepareMaskForPartitionedPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_partitioned_permutation");
+	}
+
+	@Test
+	public void testComputeMeansFromPredictions() {
+		runShapExplainerUnitTest("compute_means_from_predictions");
+	}
+
+	@Test
+	public void testComputePhisFromPredictionMeans() {
+		runShapExplainerUnitTest("compute_phis_from_prediction_means");
+	}
+
+	@Test
+	public void testComputePhisFromPredictionMeansNonVars() {
+		runShapExplainerUnitTest("compute_phis_from_prediction_means_non_vars");
+	}
+
+	@Test
+	public void testPrepareFullMask() {
+		runShapExplainerUnitTest("prepare_full_mask");
+	}
+
+	@Test
+	public void testPrepareMaskedXBg() {
+		runShapExplainerUnitTest("prepare_masked_X_bg");
+	}
+
+	@Test
+	public void testPrepareMaskedXBgIndependentPerms() {
+		runShapExplainerUnitTest("prepare_masked_X_bg_independent_perms");
+	}
+
+	@Test
+	public void testApplyFullMask() {
+		runShapExplainerUnitTest("apply_full_mask");
+	}
+
+	/**
+	 * Runs shapExplainerUnit.dml for the given sub-test and compares the
+	 * computed result against the expected result written by the script.
+	 *
+	 * @param testType name of the sub-test, passed as $1 to the DML script
+	 */
+	private void runShapExplainerUnitTest(String testType) {
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		try {
+			//note: load configuration exactly once (was duplicated before)
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			//execute given unit test
+			fullDMLScriptName = HOME + TEST_NAME + "Unit.dml";
+			programArgs = new String[]{"-args", testType, output("R"), output("R_expected")};
+			runTest(true, false, null, -1);
+
+			//compare to expected result
+			HashMap<CellIndex, Double> result = readDMLMatrixFromOutputDir("R");
+			HashMap<CellIndex, Double> result_expected = readDMLMatrixFromOutputDir("R_expected");
+
+			TestUtils.compareMatrices(result, result_expected, 1e-3, testType+"_result", testType+"_expected");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+
+	@Test
+	public void testShapExplainerDummyData(){
+		runShapExplainerComponentTest(false);
+	}
+	//TODO add test with real data
+
+	/**
+	 * Runs shapExplainerComponent.dml and compares both the computed phis and
+	 * the expected value of the model against the matrices written by the
+	 * script (suffixes "_phis" and "_e" on the output paths).
+	 *
+	 * @param useRealData currently unused; placeholder for a real-data variant
+	 */
+	private void runShapExplainerComponentTest(Boolean useRealData) {
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		try {
+			//note: load configuration exactly once (was duplicated before)
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			//execute component test
+			fullDMLScriptName = HOME + TEST_NAME + "Component.dml";
+			programArgs = new String[]{"-args", output("R"), output("R_expected")};
+			runTest(true, false, null, -1);
+
+			//compare to expected phis
+			HashMap<CellIndex, Double> result = readDMLMatrixFromOutputDir("R_phis");
+			HashMap<CellIndex, Double> result_expected = readDMLMatrixFromOutputDir("R_expected_phis");
+
+			TestUtils.compareMatrices(result, result_expected, 1e-3, "explainer_result_phis", "explainer_expected_phis");
+
+			//compare to expected value of model
+			HashMap<CellIndex, Double> result_e = readDMLMatrixFromOutputDir("R_e");
+			HashMap<CellIndex, Double> result_expected_e = readDMLMatrixFromOutputDir("R_expected_e");
+
+			TestUtils.compareMatrices(result_e, result_expected_e, 1e-3, "explainer_result_e", "explainer_expected_e");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/shapExplainerComponent.dml 
b/src/test/scripts/functions/builtin/shapExplainerComponent.dml
new file mode 100644
index 0000000000..8bf444b665
--- /dev/null
+++ b/src/test/scripts/functions/builtin/shapExplainerComponent.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+########################################################################################################
+# THIS TEST IS HIGHLY DEPENDENT ON THE SAMPLING!
+# Changes in the dataset or number of samples etc. might already be enough to change the expected result.
+########################################################################################################
+
+# Component test for the shapExplainer builtin.
+# $1: base path for computed results, $2: base path for expected results
+# (the script appends "_phis" and "_e" suffixes for the two outputs).
+
+# arguments passed to the model function besides X
+model_args = list(mult=1)
+# three instances to be explained, three features each
+x_instances = matrix("100 200 300    100 300 400    100 100 500", rows=3, cols=3)
+# background dataset used for masking and the expected value
+X_bg = matrix("11 12  13    21 22 23    31 32 33    41 42 43", rows=4, cols=3)
+n_permutations = 2
+n_samples = 3
+seed = 42
+
+# model for explainer test: prediction is the (scaled) row sum
+dummyModel = function(Matrix[Double] X, Double mult)
+  return(Matrix[Double] P){
+  P = rowSums(X)*mult
+}
+
+[result_phis, result_e] = shapExplainer("dummyModel", model_args, x_instances, X_bg, n_permutations, n_samples, 0, as.matrix(-1), seed, 1)
+result_e = cbind(as.matrix(result_e), as.matrix(0))
+# TODO for some reason storing just the scalar results in errors, so we create a small matrix by padding with a zero.
+# Might be due to comma vs dot separation of decimals in strings if the system uses a German locale or other.
+
+expected_result_phis = matrix("69 168 267 69 268 367 69 68 467", rows=3, cols=3)
+expected_result_e = matrix("96 0", rows=1, cols=2)
+
+path_phis=$1+"_phis"
+path_e=$1+"_e"
+path_expected_phis=$2+"_phis"
+path_expected_e=$2+"_e"
+
+write(result_phis, path_phis)
+write(result_e, path_e)
+write(expected_result_phis, path_expected_phis)
+write(expected_result_e, path_expected_e)
+
diff --git a/src/test/scripts/functions/builtin/shapExplainerUnit.dml 
b/src/test/scripts/functions/builtin/shapExplainerUnit.dml
new file mode 100644
index 0000000000..c5f227a67e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/shapExplainerUnit.dml
@@ -0,0 +1,104 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("scripts/builtin/shapExplainer.dml") as shap;
+
+# Unit tests for internal functions of shapExplainer.
+# $1: name of the sub-test to run, $2: output path for the computed result,
+# $3: output path for the expected result.
+
+if ($1 == 'prepare_mask_for_permutation') {
+  # mask for a full permutation of three features
+  perm = matrix("3 1 2", cols=3, rows=1)
+  result = shap::prepare_mask_for_permutation(permutation=perm)
+  expected_result = matrix("0 0 0 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0", rows=6, cols=3)
+
+} else if ($1 == 'prepare_mask_for_partial_permutation') {
+  # mask where the last two indices are non-varying
+  perm = matrix("4 1 2", cols=3, rows=1)
+  result = shap::prepare_mask_for_permutation(permutation=perm, n_non_varying_inds=2)
+  expected_result = matrix("0 0 1 0 1   0 0 0 1 0   1 0 0 1 0   1 1 0 1 0    0 1 1 0 1  1 1 1 0 1", rows=6, cols=5)
+
+} else if ($1 == 'prepare_mask_for_partitioned_permutation') {
+  # mask where features are grouped into partitions
+  perm = matrix("4 1 2", cols=3, rows=1)
+  partitions = matrix("2 4 3 5", cols=2, rows=2)
+  result = shap::prepare_mask_for_permutation(permutation=perm, partitions=partitions)
+  expected_result = matrix("0 0 0 0 0   0 0 0 1 1   1 0 0 1 1   1 1 1 1 1   0 1 1 0 0   1 1 1 0 0", rows=6, cols=5)
+
+} else if ($1 == 'compute_means_from_predictions') {
+  # means over groups of 2 consecutive predictions
+  p = matrix("2 3 3 4 4 5", rows=6, cols=1)
+  result = shap::compute_means_from_predictions(p, 2)
+  expected_result = matrix("2.5 3.5 4.5", rows=3, cols=1)
+
+} else if ($1 == 'compute_phis_from_prediction_means') {
+  # phis from prediction means for a single permutation
+  permutation = matrix("2 3 4 1 5", cols=5, rows=1)
+  P_perm = matrix("10 21 22 23 24 100 31 32 33 34", rows=10, cols=1)
+  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation)
+  expected_result = matrix("1 38.5 1 1 48.5", rows=5, cols=1)
+
+} else if ($1 == 'compute_phis_from_prediction_means_non_vars') {
+  # compute_phis_from_prediction_means with non-varying indices
+  permutation = matrix("3 4 2 1 5", cols=5, rows=1)
+  non_varying_inds= matrix("2", rows=1, cols=1)
+  P_perm = matrix("10 22 23 24 100 31 32 33", rows=8, cols=1)
+  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation, non_var_inds=non_varying_inds)
+  expected_result = matrix("1 0 39.5 1 48.5", rows=5, cols=1)
+
+} else if ($1 == 'prepare_full_mask') {
+  # prepare_full_mask
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  result = shap::prepare_full_mask(mask, 3)
+  # NOTE(review): the next line overwrites the prepare_full_mask result, so this
+  # branch effectively tests u_repeatRows only — confirm which call is intended.
+  result = shap::u_repeatRows(mask,3)
+  expected_result = matrix("1 0 1 0 1 0 0 1 0 1 0 1", rows=6, cols=2)
+
+} else if ($1 == 'prepare_masked_X_bg') {
+  # prepare_masked_X_bg
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
+  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
+  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0", rows=6, cols=2)
+
+} else if ($1 == 'prepare_masked_X_bg_independent_perms') {
+  # prepare_masked_X_bg for independent permutations
+  mask = matrix("1 0 0 1 1 0 0 1", rows=4, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12  21 22  31 32  41 42  51 52  61 62", rows=6, cols=2)
+  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 2)
+  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0  0 42  0 52  0 62 41 0  51 0  61 0", rows=12, cols=2)
+
+} else if ($1 == 'apply_full_mask') {
+  # apply_full_mask
+  x_row = matrix("100 200", rows=1, cols=2)
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
+  masked_X_bg = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
+  result = shap::apply_full_mask(x_row, full_mask, masked_X_bg)
+  expected_result = matrix("100 12 100 22 100 32 11 200 21 200 31 200", rows=6, cols=2)
+
+} else {
+  print("Test type "+$1+" unknown.")
+  result = matrix("100 100", rows=1, cols=2)
+  expected_result = matrix("0 0", rows=1, cols=2)
+}
+
+write(result, $2)
+write(expected_result, $3)
+

Reply via email to