This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new ad89941 [SYSTEMDS-2869] Built-in functions KNN and KNNBF (k-nearest
neighbors)
ad89941 is described below
commit ad899416d3d583df9bf4c6c16755e0bd39382385
Author: ywcb00 <[email protected]>
AuthorDate: Sun Feb 21 00:30:46 2021 +0100
[SYSTEMDS-2869] Built-in functions KNN and KNNBF (k-nearest neighbors)
DIA project WS2020/21.
Closes #2869.
Co-authored-by: Metka Batič <[email protected]>
Co-authored-by: Matthias Kargl <[email protected]>
---
.github/workflows/functionsTests.yml | 3 +-
scripts/builtin/knn.dml | 639 +++++++++++++++++++++
scripts/builtin/knnbf.dml | 58 ++
.../java/org/apache/sysds/common/Builtins.java | 4 +-
src/test/java/org/apache/sysds/test/TestUtils.java | 409 ++++++-------
.../test/functions/builtin/BuiltinKNNBFTest.java | 118 ++++
.../test/functions/builtin/BuiltinKNNTest.java | 130 +++++
src/test/scripts/functions/builtin/knn.R | 52 ++
src/test/scripts/functions/builtin/knn.dml | 35 ++
src/test/scripts/functions/builtin/knnbf.dml | 28 +
.../scripts/functions/builtin/knnbfReference.dml | 29 +
src/test/scripts/installDependencies.R | 2 +
12 files changed, 1312 insertions(+), 195 deletions(-)
diff --git a/.github/workflows/functionsTests.yml
b/.github/workflows/functionsTests.yml
index cf64a2f..5e7466c 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -46,7 +46,8 @@ jobs:
"**.functions.builtin.**",
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
"**.functions.dnn.**,**.functions.misc.**,**.functions.mlcontext.**,**.functions.paramserv.**",
-
"**.functions.nary.**,**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.quaternary.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
+ "**.functions.nary.**,**.functions.quaternary.**",
+
"**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
"**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**,**.functions.transform.**",
"**.functions.unary.matrix.**"
]
diff --git a/scripts/builtin/knn.dml b/scripts/builtin/knn.dml
new file mode 100644
index 0000000..8e86ba3
--- /dev/null
+++ b/scripts/builtin/knn.dml
@@ -0,0 +1,639 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# THIS SCRIPT IMPLEMENTS KNN( K Nearest Neighbor ) ALGORITHM
+#
+# INPUT PARAMETERS:
+#
---------------------------------------------------------------------------------------------
+# NAME TYPE DEFAULT OPTIONAL MEANING
+#
---------------------------------------------------------------------------------------------
+# Train Matrix --- N The input matrix as features
+# Test Matrix --- N The input matrix for nearest
neighbor search
+# CL Matrix --- Y The input matrix as target
+# CL_T Integer 0 Y The target type of matrix CL
whether
+# columns in CL are continuous ( =1
) or
+# categorical ( =2 ) or
+# not specified ( =0 )
+# trans_continuous Boolean FALSE Y Option flag for continuous feature
transformed to [-1,1]:
+# FALSE = do not transform
continuous variable;
+# TRUE = transform continuous
variable;
+# k_value int 5 Y k value for KNN, ignore if
select_k enable
+# select_k Boolean FALSE Y Use k selection algorithm to
estimate k
+# ( TRUE means yes )
+# k_min int 1 Y Min k value( available if
select_k = 1 )
+# k_max int 100 Y Max k value( available if
select_k = 1 )
+# select_feature Boolean FALSE Y Use feature selection algorithm to
select feature
+# ( TRUE means yes )
+# feature_max int 10 Y Max feature selection
+# interval int 1000 Y Interval value for K selecting (
available if select_k = 1 )
+# feature_importance Boolean FALSE Y Use feature importance
algorithm to estimate each feature
+# ( TRUE means yes )
+# predict_con_tg int 0 Y Continuous target predict
function: mean(=0) or
+# median(=1)
+# START_SELECTED Matrix --- Y feature selection initial value
+#
---------------------------------------------------------------------------------------------
+# OUTPUT: Matrix NNR, Matrix PR, Matrix FEATURE_IMPORTANCE_VALUE
+#
+
+m_knn = function(
+ Matrix[Double] Train,
+ Matrix[Double] Test,
+ Matrix[Double] CL,
+ Integer CL_T = 0,
+ Integer trans_continuous = 0,
+ Integer k_value = 5,
+ Integer select_k = 0,
+ Integer k_min = 1,
+ Integer k_max = 100,
+ Integer select_feature = 0,
+ Integer feature_max = 10,
+ Integer interval = 1000,
+ Integer feature_importance = 0,
+ Integer predict_con_tg = 0,
+ Matrix[Double] START_SELECTED = matrix(0, 0, 0)
+)return(
+ Matrix[Double] NNR_matrix,
+ Matrix[Double] CL_matrix,
+ Matrix[Double] m_feature_importance
+){
+
+ m_feature_importance = matrix(0, 0, 0);
+
+ #data prepare
+ if( trans_continuous == 1 ){
+ Train = prepareKNNData( Train);
+ Test = prepareKNNData( Test);
+ }
+
+ n_records = nrow( Train );
+ n_features = ncol( Train );
+ s_selected_k = 5;
+ m_selected_feature = matrix(1,rows=1,cols=n_records);
+ if( select_k == 1 | select_feature==1 ){
+ #parameter check
+ #parameter re-define
+ if( select_k==1 ){
+ if( k_max >= n_records ){
+ k_max = n_records - 1;
+ print( "k_max should no greater than number of record, change k_max
equal " +
+ "( number of record - 1 ) : " + k_max );
+ }
+ if( k_max >= interval ){
+ interval = k_max + 1;
+ # k_max should equal interval -1, because we drop self when search nn.
+ print( "interval should be no less than k_max, change interval equal :
" +
+ interval );
+ }
+ if( k_max <= 1 )
+ stop( "uncorrect k_max value" );
+ if( k_min >= k_max )
+ stop( "k_min >= k_max" );
+ }
+ if( select_feature == 1 ){
+ if( k_value >= n_records ){
+ k_value = n_records - 1;
+ print( "k_value should be no greater than number of record, change
k_value equal " +
+ "( number of record - 1 ) : " + k_value );
+ }
+ #Select feature only
+ if( nrow(START_SELECTED) == 0 & ncol(START_SELECTED) == 0 )
+ m_start_selected_feature = matrix( 0,1,n_features );
+ else
+ m_start_selected_feature = START_SELECTED;
+ }
+
+ if( select_k == 1 & select_feature == 1){
+ #Combined k and feature selection
+ print("Start combined k and feature selection ...");
+ [m_selected_feature,s_selected_k] =
+ getSelectedFeatureAndK( Train,CL,CL_T,m_start_selected_feature,
+ feature_max,k_min,k_max,interval );
+ }
+ else if( select_k == 1 ){
+ #Select k only
+ print("Start k select ...");
+ s_selected_k = getSelectedKBase( Train,CL,CL_T,k_min,k_max,interval );
+ }
+ else if( select_feature == 1 ){
+ #Select feature only
+ print("Start feature selection ... ");
+ [m_selected_feature,d_err] =
+ getSelectedFeature( Train,CL,CL_T,m_start_selected_feature,
+ feature_max,k_value,interval );
+ }
+ }
+
+ if( feature_importance == 1){
+ if( k_value >= n_records ){
+ k_value = n_records - 1;
+ print( "k_value should be no greater than number of record, make k_value
equal " +
+ "( number of record - 1 ) : " + k_value );
+ }
+ [m_feature_importance,m_norm_feature_importance] =
+ getFeatureImportance(Train,CL,CL_T,k_value);
+ }
+
+ NNR_matrix = naiveKNNsearch(P=Train,Q=Test,K=k_value);
+
+ CL_matrix = matrix( 0,nrow( Test ),1 );
+
+ for(i in 1 : nrow(NNR_matrix))
+ {
+ NNR_tmp_matrix = matrix( 0,k_value,1 );
+ for( j in 1:k_value )
+ NNR_tmp_matrix[j,1] = CL[as.scalar( NNR_matrix[i,j] ),1];
+
+ if(CL_T == 2) {
+ t_cl_value = as.scalar( rowIndexMax( t(NNR_tmp_matrix) ) );
+ }
+ else {
+ if ( predict_con_tg == 0)
+ t_cl_value = mean( NNR_tmp_matrix );
+ else if(predict_con_tg == 1)
+ t_cl_value = median( NNR_tmp_matrix );
+ }
+
+ CL_matrix[i,1] = t_cl_value;
+ }
+}
+
+#naive knn search implement
+naiveKNNsearch = function(
+ Matrix[Double] P,
+ Matrix[Double] Q,
+ Integer K
+)return(
+ Matrix[Double] O
+){
+ num_records = nrow (P);
+ num_features = ncol (P);
+ num_queries = nrow (Q);
+ Qt = t(Q);
+ PQt = P %*% Qt;
+ P2 = rowSums (P ^ 2);
+ D = -2 * PQt + P2;
+ if (K == 1) {
+ Dt = t(D);
+ O = rowIndexMin (Dt);
+ } else {
+ O = matrix (0, rows = num_queries, cols = K);
+ parfor (i in 1:num_queries) {
+ D_sorted=order(target=D[,i], by=1, decreasing=FALSE, index.return=TRUE);
+ O[i,] = t(D_sorted[1:K,1]);
+ }
+ }
+}
+
+#naive knn search for predict value only implement
+#TODO eliminate redundancy
+naiveKNNsearchForPredict = function(
+ matrix[double] P,
+ matrix[double] Q,
+ matrix[double] L,
+ integer K
+)return(
+ matrix[double] OL
+){
+ num_records = nrow (P);
+ num_features = ncol (P);
+ num_queries = nrow (Q);
+ Qt = t(Q);
+ PQt = P %*% Qt;
+ P2 = rowSums (P ^ 2);
+ D = -2 * PQt + P2;
+ if (K == 1) {
+ Dt = t(D);
+ O = rowIndexMin (Dt);
+ OL = matrix (0, rows = num_queries, cols = 1)
+ parfor( i in 1:num_queries){
+ OL[i,] = L[as.scalar(O[i,1]),1]
+ }
+ } else {
+ OL = matrix (0, rows = num_queries, cols = K);
+ parfor (i in 1:num_queries) {
+ D_sorted=order(target=cbind(D[,i],L), by=1, decreasing=FALSE,
index.return=FALSE);
+ OL[i,] = t(D_sorted[1:K,2]);
+ }
+ }
+}
+
+getErr_k = function ( matrix[double] in_m_neighbor_value,
+ matrix[double] in_m_cl,
+ integer in_i_cl_type ,
+ integer in_i_k_min )
+ return ( matrix[double] out_m_err )
+{
+ i_col = ncol( in_m_neighbor_value );
+ i_row = nrow( in_m_neighbor_value );
+
+ out_m_err = matrix( 0,i_row,i_col - in_i_k_min + 1 );
+ if( in_i_cl_type == 2 ) #category
+ m_correct = in_m_neighbor_value != in_m_cl[1:i_row,];
+ else #continuous
+ m_correct = (in_m_neighbor_value - in_m_cl[1:i_row,])^2;#ppred(
in_m_neighbor_value,in_m_cl,"-" );
+ parfor( i in 1:i_col-in_i_k_min+1 ,check = 0 ){
+ out_m_err[,i] =
+ ( rowSums( m_correct[,1:in_i_k_min + i - 1] ) / ( in_i_k_min + i - 1
) );
+ }
+ #return err for each record and each k ( belonging to range 1~max );
+}
+
+eliminateModel = function ( double s_err_mean, double s_err_vars, integer
i_row )
+ return( boolean out_b_inactived ){
+ #alpha, beta, gamma, delta
+ d_gamma = 0.001;
+ d_delta = 0.001;
+ tmp_d_delta = cdf(target = (-d_gamma - s_err_mean)/s_err_vars,
dist="t",df=i_row-1);
+ out_b_inactived = (tmp_d_delta < d_delta)
+}
+
+getErr = function ( matrix[double] in_m_neighbor_value,
+ matrix[double] in_m_cl,
+ integer in_i_cl_type )
+ return ( matrix[double] out_m_err )
+{
+ i_col = ncol( in_m_neighbor_value );
+ i_row = nrow( in_m_neighbor_value );
+ if( in_i_cl_type == 2 ) #category
+ m_correct = in_m_neighbor_value != in_m_cl[1:i_row,];
+ else #continuous
+ m_correct = (in_m_neighbor_value - in_m_cl[1:i_row,])^2;
+ out_m_err = ( rowSums( m_correct[,1:i_col] )/( i_col ) );
+}
+
+# getSelectedFeatureAndK:
+# Combine k and feature selection algorithm.
+# Refer to ADD part "8.Combined k and feature selection"
+# Argument:
+# in_m_data input matrix as features
+# in_m_data_target input matrix as target value
+# in_i_is_categorical 1 = category , 0 = continuous
+# in_m_init_selected S.user, initial selected feature which use
specified
+# in_i_max_select J, max feature selected
+# k_min minimum k
+# k_max maximum k
+# interval block size for BRACE algorithm
+#
+# Return:
+# out_m_selected_feature output matrix for feature selection
+# out_i_selected_k output k value for k selection
+
+getSelectedFeatureAndK = function (
+ matrix[double] in_m_data,
+ matrix[double] in_m_data_target,
+ integer in_i_is_categorical, # 1 = category , 0 = continuous
+ matrix[double] in_m_init_selected,
+ integer in_i_max_select,
+ integer k_min,
+ integer k_max,
+ integer interval )
+return(
+ matrix[double] out_m_selected_feature,
+ integer out_i_selected_k
+ )
+{
+ m_err = matrix( 0,1,k_max-k_min+1 );
+ m_feature = matrix( 0,k_max-k_min+1,ncol( in_m_data ) );
+ #Step 1. For each k in [k_min,k_max] ( k_min has default value 1, k_max has
default value 100 )
+ #in parallel select relevant features using FS+BRACE or schemata search
described in Section 7.
+ parfor( i in k_min:k_max,check=0 ){
+ [m_selected_feature,d_err] =
+ getSelectedFeature( in_m_data,in_m_data_target,in_i_is_categorical,
+ in_m_init_selected,in_i_max_select,i,interval );
+ m_err[1,i] = d_err;
+ m_feature[i,] = m_selected_feature;
+ }
+ #Step 2. Output the combination of features and k with the smallest LOOCV
error.
+ i_min_err_index = as.integer( as.scalar( rowIndexMin( m_err ) ) );
+ out_m_selected_feature = m_feature[i_min_err_index,];
+ out_i_selected_k = i_min_err_index + k_min - 1;
+}
+
+getFeatureImportance = function (
+ matrix[double] in_m_data,
+ matrix[double] in_m_data_target,
+ integer in_i_is_categorical, # 1 = category , 0 = continuous
+ integer k_value)
+return(
+ matrix[double] out_m_feature_importance,
+ matrix[double] out_m_norm_feature_importance
+ )
+{
+ n_feature = ncol(in_m_data)
+ n_record = nrow(in_m_data)
+ if(n_feature <= 1)
+ stop("can't estimate feature importance when ncol = 1")
+
+ m_err = matrix( 0,n_record,n_feature);
+ for(i_feature in 1:n_feature){
+ m_feature_select = matrix(1,1,n_feature)
+ m_feature_select[1,i_feature] = 0;
+ m_in_tmp_data = removeEmpty(target=in_m_data,margin="cols", select=
m_feature_select)
+ m_neighbor_value = getKNeighbor(
m_in_tmp_data,m_in_tmp_data,in_m_data_target,k_value);
+ m_tmp_err = getErr( m_neighbor_value,in_m_data_target ,in_i_is_categorical
);
+ m_err[,i_feature] = m_tmp_err
+ }
+ out_m_feature_importance = colSums( m_err );
+ out_m_norm_feature_importance =
+ out_m_feature_importance / as.scalar(rowSums(out_m_feature_importance))
+}
+
+# prepareKNNData:
+# Do data prepare - [-1,1] transform for continuous variables
+# Argument:
+# * in_m_data input matrix as features
+prepareKNNData = function(matrix[double] in_m_data)
+ return(matrix[double] out_m_data)
+{
+ m_colmax = colMaxs(in_m_data);
+ m_colmin = colMins(in_m_data);
+ out_m_data = 2 * (in_m_data - m_colmin ) / ( m_colmax - m_colmin ) - 1;
+}
+
+getKNeighbor = function(matrix[double] in_m_data,
+ matrix[double] in_m_test_data,
+ matrix[double] in_m_cl,
+ integer in_i_k_max)
+ return (matrix[double] out_m_neighbor_value)
+{
+ # to naive
+ m_search_result = naiveKNNsearchForPredict(in_m_data, in_m_test_data,
in_m_cl, in_i_k_max + 1)
+ out_m_neighbor_value = m_search_result[ , 2 : in_i_k_max + 1]
+}
+
+# getSelectedKBase:
+# k selection algorithm with simple KNN algorithm.
+# Argument:
+# * in_m_data input matrix as features
+# * in_m_data_target input matrix as target value
+# * in_i_is_categorical 1 = category , 0 = continuous
+# * k_min minimum k
+# * k_max maximum k
+# * interval block size
+#
+# Return:
+# * k output k value for k selection
+getSelectedKBase = function(matrix[double] in_m_data,
+ matrix[double] in_m_data_target,
+ integer in_i_is_categorical, # 1 = category, 0 = continuous
+ integer k_min,
+ integer k_max,
+ integer interval)
+ return (integer k)
+{
+ b_continue_loop = TRUE;
+ i_iter = 1;
+ i_record = nrow(in_m_data);
+
+ i_active_model_number = k_max - k_min + 1;
+ m_active_flag = matrix(0, 1, i_active_model_number);
+
+ m_iter_err_sum = matrix(0, 1, k_max - k_min + 1);
+ m_iter_err_sum_squared = matrix(0, 1, k_max - k_min + 1);
+ while(b_continue_loop)
+ {
+ # 1.build k-d tree? , or use hash method
+ # 2.search data to get k_max nearest neighbor
+ i_process_item = i_iter * interval;
+ if(i_process_item >= i_record) {
+ i_process_item = i_record;
+ b_continue_loop = FALSE;
+ }
+ i_process_begin_item = ((i_iter - 1) * interval) + 1;
+ i_process_end_item = i_process_item;
+
+ m_neighbor_value = getKNeighbor(in_m_data, in_m_data[i_process_begin_item
: i_process_end_item, ], in_m_data_target, k_max);
+ # 3.get matrix of err from k_min to k_max
+ m_err = getErr_k(m_neighbor_value, in_m_data_target[i_process_begin_item :
i_process_end_item, ], in_i_is_categorical, k_min);
+
+ # 4.check this matrix to drop unnecessary records
+ m_active_flag_tmp = matrix(0, 1, ncol(m_err));
+
+ s_rows_number = i_process_item;
+
+ m_iter_err_sum = colSums(m_err) + m_iter_err_sum;
+ m_iter_err_sum_squared = colSums(m_err ^ 2) + m_iter_err_sum_squared;
+
+ m_err_mean = - outer(t(m_iter_err_sum), m_iter_err_sum , "-") /
s_rows_number;
+ m_err_vars = ( m_err_mean ^2 * s_rows_number -
+ 2 * m_err_mean * m_iter_err_sum + m_iter_err_sum_squared) /
(s_rows_number-1);
+ m_err_vars = sqrt(m_err_vars);
+
+ parfor(i in 1 : ncol(m_err), check = 0) {
+ parfor(j in 1 : ncol(m_err), check = 0) {
+ b_execute_block = !(j == i
+ | as.scalar(m_active_flag_tmp[1, i]) == 1 # i has dropped, ignore
this case
+ | as.scalar(m_active_flag_tmp[1, j]) == 1) # j has dropped, ignore
this case
+ if(b_execute_block) {
+ b_flag = eliminateModel(as.scalar(m_err_mean[i, j]),
as.scalar(m_err_vars[i, j]), s_rows_number);
+ if(b_flag == TRUE)
+ m_active_flag_tmp[1, i] = 1;
+ }
+ }
+ }
+
+ m_active_flag = ((m_active_flag + m_active_flag_tmp) >= 1);
+ i_active_model_number = -sum(m_active_flag - 1);
+
+ # 5.break while check
+ if(i_active_model_number <= 1)
+ b_continue_loop = FALSE;
+
+ i_iter = i_iter + 1;
+ print("i_iter" + i_iter)
+ }
+
+ k = 0;
+ if(i_active_model_number == 0) {
+ print("All k kick out, use min of range " + k_min);
+ k = k_min;
+ }
+ else if(i_active_model_number == 1) {
+ k = k_min + as.integer(as.scalar(rowIndexMin(m_active_flag))) - 1;
+ print( "Get k, which value is " + k );
+ }
+ else {
+ m_err_for_order =
+ cbind(t(m_iter_err_sum), matrix(seq(k_min, k_max, 1), k_max - k_min + 1,
1));
+ m_err_for_order = removeEmpty(
+ target = m_err_for_order * t(m_active_flag == 0), margin = "rows");
+ for(i in 1 : nrow(m_err_for_order)) {
+ print("k:" + as.scalar(m_err_for_order[i, 2]) +
+ ", err:" + as.scalar(m_err_for_order[i, 1]));
+ }
+ m_err_order = order(target = m_err_for_order, by = 1, decreasing = FALSE,
index.return = FALSE);
+ k = as.integer(as.scalar(m_err_order[1, 2]));
+ print("Get minimum LOOCV error, which value is " + k);
+ }
+}
+
+# getSelectedFeature:
+# feature selection algorithm.
+# Refer to ADD part "7.1 FS+BRACE"
+# Argument:
+# in_m_data input matrix as features
+# in_m_data_target input matrix as target value
+# in_i_is_categorical 1 = category , 0 = continuous
+# in_m_init_selected S.user, initial selected feature which use
specified
+# in_i_max_select J, max feature selected
+# k_value k
+# interval block size for BRACE algorithm
+#
+# Return:
+# out_m_selected_feature output matrix for feature selection
+# out_d_min_LOOCV output err
+
+getSelectedFeature = function (
+ matrix[double] in_m_data,
+ matrix[double] in_m_data_target,
+ integer in_i_is_categorical, # 1 = category , 0 = continuous
+ matrix[double] in_m_init_selected,
+ integer in_i_max_select,
+ integer k_value,
+ integer interval )
+return(
+ matrix[double] out_m_selected_feature,
+ double out_d_min_LOOCV
+ )
+{
+ i_n_record = nrow( in_m_data );
+ i_n_column = ncol( in_m_data );
+ m_main_selected_flag = in_m_init_selected;
+ b_no_feature_selected = TRUE;
+ if( sum( in_m_init_selected ) >= 1 )
+ b_no_feature_selected = FALSE;
+
+ d_max_err_value = ( max( in_m_data_target ) - min( in_m_data_target ) ) *
100;
+ b_continue_main_loop = TRUE; #level 1 while loop flag
+ d_min_LOOCV = Inf;
+ while( b_continue_main_loop ){
+ m_feature_selected_flag = m_main_selected_flag;
+ m_this_model_selected_flag = TRUE;
+ i_index_min_LOOCV = -1; # flag for which model wins in BRACE algorithm
+ b_selected_morethan_one = FALSE;
+ b_continue_loop = TRUE; #level 2 while loop flag
+ i_iter = 1;
+ m_iter_err_sum = matrix( 0,1,i_n_column+1 );
+ m_iter_err_sum_squared = matrix( 0,1,i_n_column+1 );
+ while( b_continue_loop ){
+ i_process_item = i_iter*interval;
+ if( i_process_item >= i_n_record ){
+ i_process_item = i_n_record;
+ b_continue_loop = FALSE;
+ }
+ i_process_begin_item = (i_iter - 1)*interval + 1
+ i_process_end_item = i_process_item
+ m_err = matrix( d_max_err_value,i_process_end_item -
i_process_begin_item + 1,i_n_column+1 );
+ if( b_no_feature_selected == TRUE ){
+ parfor( i in 1:i_n_column ,check=0){
+ if( as.scalar( m_feature_selected_flag[1,i] ) != 1 ){
+ m_tmp_process_data = in_m_data[,i];
+ m_neighbor_value = getKNeighbor(m_tmp_process_data,
+ m_tmp_process_data[i_process_begin_item:i_process_end_item,],
in_m_data_target,k_value );
+ m_tmp_err = getErr(m_neighbor_value,
+ in_m_data_target[i_process_begin_item:i_process_end_item,],
in_i_is_categorical );
+ m_err[,i] = m_tmp_err;
+ }
+ }
+ }else{
+ #Use m_main_selected_flag but not m_feature_selected_flag,
+ # m_main_selected_flag: which feature are init selected
+ # m_feature_selected_flag: which feature are dropped & init selected
+ m_tmp_data = removeEmpty( target=in_m_data ,margin="cols", select =
m_main_selected_flag);
+ if( m_this_model_selected_flag == TRUE ){
+ m_neighbor_value = getKNeighbor(
+
m_tmp_data,m_tmp_data[i_process_begin_item:i_process_end_item,],in_m_data_target,
k_value );
+ m_tmp_err = getErr(
m_neighbor_value,in_m_data_target[i_process_begin_item:i_process_end_item,],in_i_is_categorical
);
+ m_err[,i_n_column+1] = m_tmp_err;
+ }
+ parfor( i in 1:i_n_column ,check=0){
+ if( as.scalar( m_feature_selected_flag[1,i] ) != 1 ){
+ m_tmp_process_data = cbind( m_tmp_data,in_m_data[,i] );
+ m_neighbor_value = getKNeighbor(
+
m_tmp_process_data,m_tmp_process_data[i_process_begin_item:i_process_end_item,],in_m_data_target,k_value
);
+ m_tmp_err = getErr(
+
m_neighbor_value,in_m_data_target[i_process_begin_item:i_process_end_item,],in_i_is_categorical
);
+ m_err[,i] = m_tmp_err;
+ }
+ }
+ }
+ if( m_this_model_selected_flag == TRUE )
+ m_active_flag_tmp = cbind( m_feature_selected_flag,matrix( 0,1,1 ) );
+ else
+ m_active_flag_tmp = cbind( m_feature_selected_flag,matrix( 1,1,1 ) );
+ s_rows_number = i_process_item
+ m_iter_err_sum = colSums(m_err) + m_iter_err_sum
+ m_iter_err_sum_squared = colSums(m_err ^ 2) + m_iter_err_sum_squared
+ m_err_mean = - outer(t(m_iter_err_sum), m_iter_err_sum , "-") /
s_rows_number
+ m_err_vars = ( m_err_mean ^2 * s_rows_number -
+ 2 * m_err_mean * m_iter_err_sum + m_iter_err_sum_squared) /
(s_rows_number-1)
+ m_err_vars = sqrt(m_err_vars)
+ parfor( i in 1:ncol( m_err ) ){
+ parfor( j in 1:ncol( m_err ) ,check=0){
+ b_execute_block = TRUE;
+ if( j==i ) b_execute_block = FALSE;
+ if( as.scalar( m_active_flag_tmp[1,i] ) == 1 ) b_execute_block =
FALSE;
+ #i has dropped, ignore this case
+ if( as.scalar( m_active_flag_tmp[1,j] ) == 1 ) b_execute_block =
FALSE;
+ #j has dropped, ignore this case
+ if( b_execute_block ){
+ b_flag = eliminateModel(
as.scalar(m_err_mean[i,j]),as.scalar(m_err_vars[i,j]),s_rows_number);
+ if( b_flag == TRUE )
+ m_active_flag_tmp[1,i] = 1;
+ }
+ }
+ }
+ #We mark bit to 1 for selected feature before current loop,
+ #and mark bit to 1 also for dropped feature in current loop
+ if( sum( m_active_flag_tmp != 1 ) > 1 )
+ b_selected_morethan_one = TRUE;
+ m_col_sums_err = m_iter_err_sum #colSums( m_err );
+ i_index_min_LOOCV = as.scalar( rowIndexMin( m_col_sums_err ) );
+ d_min_LOOCV = as.scalar( m_col_sums_err[1,i_index_min_LOOCV] );
+ i_index_min_LOOCV = i_index_min_LOOCV%% ( i_n_column+1 )
+ if( sum( m_active_flag_tmp != 1 ) <= 1 )
+ b_continue_loop = FALSE;
+ if( as.scalar( m_active_flag_tmp[1,i_n_column+1] ) == 1 )
+ m_this_model_selected_flag = FALSE;
+ m_feature_selected_flag = m_active_flag_tmp[,1:i_n_column];
+ i_iter = i_iter + 1;
+ }
+ #select current model, jump out.
+ if( i_index_min_LOOCV == 0 ){
+ b_continue_main_loop = FALSE;
+ print( "Select Current model" );
+ }else{
+ print( "select feature " + i_index_min_LOOCV + ", change bit value to 1"
);
+ m_main_selected_flag[1,i_index_min_LOOCV] = 1;
+ b_no_feature_selected = FALSE;
+ }
+ if( sum( m_main_selected_flag - in_m_init_selected ) >= in_i_max_select ){
+ #select more than 10
+ b_continue_main_loop = FALSE;
+ }
+ if( sum( m_main_selected_flag ) == i_n_column ){
+ #all selected
+ b_continue_main_loop = FALSE;
+ }
+ }
+ out_m_selected_feature = m_main_selected_flag;
+ out_d_min_LOOCV = d_min_LOOCV;
+}
diff --git a/scripts/builtin/knnbf.dml b/scripts/builtin/knnbf.dml
new file mode 100644
index 0000000..1146680
--- /dev/null
+++ b/scripts/builtin/knnbf.dml
@@ -0,0 +1,58 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+m_knnbf = function(
+ Matrix[Double] X,
+ Matrix[Double] T,
+ Integer k_value = 5
+ ) return(
+ Matrix[Double] NNR
+ )
+{
+ num_records = nrow(X);
+ num_queries = nrow(T);
+
+ D = matrix(0, rows = num_records, cols = num_queries);
+ NNR = matrix(0, rows = num_queries, cols = k_value);
+
+ parfor(i in 1 : num_queries) {
+ D[ , i] = calculateDistance(X, T[i, ]);
+ NNR[i, ] = sortAndGetK(D[ , i], k_value);
+ }
+}
+
+calculateDistance = function(Matrix[Double] R, Matrix[Double] Q)
+ return(Matrix[Double] distances)
+{
+ NR = rowSums(R ^ 2) %*% matrix(1,1,nrow(Q));
+ NQ = matrix(1,nrow(R),1) %*% t(rowSums(Q ^ 2));
+ distances = NR + NQ - 2.0 * R %*% t(Q);
+}
+
+sortAndGetK = function(Matrix[Double] D, Integer k)
+ return (Matrix[Double] knn_)
+{
+ if(nrow(D) < k)
+ stop("can not pick "+k+" nearest neighbours from "+nrow(D)+" total
instances")
+
+ sort_dist = order(target = D, by = 1, decreasing= FALSE, index.return =
TRUE)
+ knn_ = t(sort_dist[1:k,])
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index ad9f141..9080136 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -33,7 +33,7 @@ import org.apache.sysds.common.Types.ReturnType;
* builtin functions.
*
* To add a new builtin script function, simply add the definition here
- * as well as a dml file in scripts/builtin with a matching name. On
+ * as well as a dml file in scripts/builtin with a matching name. On
* building SystemDS, these scripts are packaged into the jar as well.
*/
public enum Builtins {
@@ -136,6 +136,8 @@ public enum Builtins {
ISINF("is.infinite", false),
KMEANS("kmeans", true),
KMEANSPREDICT("kmeansPredict", true),
+ KNNBF("knnbf", true),
+ KNN("knn", true),
L2SVM("l2svm", true),
LASSO("lasso", true),
LENGTH("length", false),
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 0e4883c..0d541e7 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -91,7 +91,7 @@ import org.junit.Assert;
* <li>clean up</li>
* </ul>
*/
-public class TestUtils
+public class TestUtils
{
private static final Log LOG =
LogFactory.getLog(TestUtils.class.getName());
@@ -112,14 +112,14 @@ public class TestUtils
try {
String lineExpected = null;
String lineActual = null;
-
+
Path compareFile = new Path(expectedFile);
FileSystem fs =
IOUtilFunctions.getFileSystem(compareFile, conf);
FSDataInputStream fsin = fs.open(compareFile);
try( BufferedReader compareIn = new BufferedReader(new
InputStreamReader(fsin)) ) {
lineExpected = compareIn.readLine();
}
-
+
Path outFile = new Path(actualFile);
FSDataInputStream fsout = fs.open(outFile);
try( BufferedReader outIn = new BufferedReader(new
InputStreamReader(fsout)) ) {
@@ -132,7 +132,7 @@ public class TestUtils
fail("unable to read file: " + e.getMessage());
}
}
-
+
/**
* Compares contents of an expected file with the actual file, where
rows may be permuted
* @param expectedFile
@@ -144,7 +144,7 @@ public class TestUtils
{
try {
HashMap<CellIndex, Double> expectedValues = new
HashMap<>();
-
+
Path outDirectory = new Path(actualDir);
Path compareFile = new Path(expectedFile);
FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
@@ -166,35 +166,35 @@ public class TestUtils
if(expectedValue != 0.0)
e_list.add(expectedValue);
}
-
+
ArrayList<Double> a_list = new ArrayList<>();
for (CellIndex index : actualValues.keySet()) {
Double actualValue = actualValues.get(index);
if(actualValue != 0.0)
a_list.add(actualValue);
}
-
+
Collections.sort(e_list);
Collections.sort(a_list);
-
+
assertTrue("Matrix nzs not equal", e_list.size() ==
a_list.size());
for(int i=0; i < e_list.size(); i++)
{
assertTrue("Matrix values not equals",
Math.abs(e_list.get(i) - a_list.get(i)) <= epsilon);
}
-
+
} catch (IOException e) {
fail("unable to read file: " + e.getMessage());
}
}
-
+
/**
* <p>
* Compares the expected values calculated in Java by testcase and
which are
* in the normal filesystem, with those calculated by SystemDS located
in
* HDFS with Matrix Market format
* </p>
- *
+ *
* @param expectedFile
* file with expected values, which is located in OS
filesystem
* @param actualDir
@@ -209,33 +209,33 @@ public class TestUtils
Path compareFile = new Path(expectedFile);
FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
FSDataInputStream fsin = fs.open(compareFile);
-
+
HashMap<CellIndex, Double> expectedValues = new
HashMap<>();
String[] expRcn = null;
-
+
try(BufferedReader compareIn = new BufferedReader(new
InputStreamReader(fsin)) ) {
// skip the header of Matrix Market file
String line = compareIn.readLine();
-
+
// rows, cols and nnz
line = compareIn.readLine();
expRcn = line.split(" ");
-
+
readValuesFromFileStreamAndPut(compareIn,
expectedValues);
}
-
+
HashMap<CellIndex, Double> actualValues = new
HashMap<>();
FSDataInputStream fsout = fs.open(outDirectory);
try( BufferedReader outIn = new BufferedReader(new
InputStreamReader(fsout)) ) {
-
+
//skip MM header
String line = outIn.readLine();
-
+
//rows, cols and nnz
line = outIn.readLine();
String[] rcn = line.split(" ");
-
+
if (Integer.parseInt(expRcn[0]) !=
Integer.parseInt(rcn[0])) {
LOG.warn(" Rows mismatch: expected " +
Integer.parseInt(expRcn[0]) + ", actual " + Integer.parseInt(rcn[0]));
}
@@ -273,12 +273,12 @@ public class TestUtils
fail("unable to read file: " + e.getMessage());
}
}
-
+
/**
- * Read doubles from the input stream and put them into the given
hashmap of values.
+ * Read doubles from the input stream and put them into the given
hashmap of values.
* @param inputStream input stream of doubles with related indices
* @param values hashmap of values (initially empty)
- * @throws IOException
+ * @throws IOException
*/
public static void readValuesFromFileStream(FSDataInputStream
inputStream, HashMap<CellIndex, Double> values)
throws IOException
@@ -293,7 +293,7 @@ public class TestUtils
* @param inReader BufferedReader to read values from
* @param values hashmap where values are put
*/
- public static void readValuesFromFileStreamAndPut(BufferedReader
inReader, HashMap<CellIndex, Double> values)
+ public static void readValuesFromFileStreamAndPut(BufferedReader
inReader, HashMap<CellIndex, Double> values)
throws IOException
{
String line = null;
@@ -359,14 +359,14 @@ public class TestUtils
fail("unable to read file: " + e.getMessage());
}
}
-
+
/**
* <p>
* Compares the expected values calculated in Java by testcase and
which are
* in the normal filesystem, with those calculated by SystemDS located
in
* HDFS
* </p>
- *
+ *
* @param expectedFile
* file with expected values, which is located in OS
filesystem
* @param actualDir
@@ -402,7 +402,7 @@ public class TestUtils
}
assertEquals("for file " + actualDir + " " + countErrors + "
values are not equal", 0, countErrors);
}
-
+
/**
* <p>
* Compares the expected values calculated in Java by testcase and
which are
@@ -440,7 +440,7 @@ public class TestUtils
}
assertEquals("for file " + actualDir + " " + countErrors + "
values are not equal", 0, countErrors);
}
-
+
public static void compareTensorBlocks(TensorBlock tb1, TensorBlock
tb2) {
Assert.assertEquals(tb1.getValueType(), tb2.getValueType());
Assert.assertArrayEquals(tb1.getSchema(), tb2.getSchema());
@@ -450,12 +450,12 @@ public class TestUtils
for (int j = 0; j < tb1.getNumColumns(); j++)
Assert.assertEquals(tb1.get(new int[]{i, j}),
tb2.get(new int[]{i, j}));
}
-
+
public static TensorBlock createBasicTensor(ValueType vt, int rows, int
cols, double sparsity) {
return DataConverter.convertToTensorBlock(TestUtils.round(
MatrixBlock.randOperations(rows, cols, sparsity, 0, 10,
"uniform", 7)), vt, true);
}
-
+
public static TensorBlock createDataTensor(ValueType vt, int rows, int
cols, double sparsity) {
return DataConverter.convertToTensorBlock(TestUtils.round(
MatrixBlock.randOperations(rows, cols, sparsity, 0, 10,
"uniform", 7)), vt, false);
@@ -470,11 +470,11 @@ public class TestUtils
* @param filePath Path to the file to be read.
* @return Matrix values in a hashmap <index,value>
*/
- public static HashMap<CellIndex, Double> readDMLMatrixFromHDFS(String
filePath)
+ public static HashMap<CellIndex, Double> readDMLMatrixFromHDFS(String
filePath)
{
HashMap<CellIndex, Double> expectedValues = new HashMap<>();
-
- try
+
+ try
{
Path outDirectory = new Path(filePath);
FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
@@ -484,7 +484,7 @@ public class TestUtils
FSDataInputStream outIn =
fs.open(file.getPath());
readValuesFromFileStream(outIn, expectedValues);
}
- }
+ }
catch (IOException e) {
assertTrue("could not read from file " + filePath+":
"+e.getMessage(), false);
}
@@ -501,26 +501,26 @@ public class TestUtils
* @param filePath Path to the file to be read.
* @return Matrix values in a hashmap <index,value>
*/
- public static HashMap<CellIndex, Double> readRMatrixFromFS(String
filePath)
+ public static HashMap<CellIndex, Double> readRMatrixFromFS(String
filePath)
{
HashMap<CellIndex, Double> expectedValues = new HashMap<>();
-
- try(BufferedReader reader = new BufferedReader(new
FileReader(filePath)))
+
+ try(BufferedReader reader = new BufferedReader(new
FileReader(filePath)))
{
// skip both R header lines
String line = reader.readLine();
-
+
int matrixType = -1;
if ( line.endsWith(" general") )
matrixType = 1;
if ( line.endsWith(" symmetric") )
matrixType = 2;
-
+
if ( matrixType == -1 )
throw new RuntimeException("unknown matrix type
while reading R matrix: " + line);
-
+
line = reader.readLine(); // header line with dimension
and nnz information
-
+
while ((line = reader.readLine()) != null) {
StringTokenizer st = new StringTokenizer(line,
" ");
int i = Integer.parseInt(st.nextToken());
@@ -538,14 +538,14 @@ public class TestUtils
expectedValues.put(new
CellIndex(j, i), 1.0);
}
}
- }
+ }
catch (IOException e) {
assertTrue("could not read from file " + filePath,
false);
}
-
+
return expectedValues;
}
-
+
/**
* Reads a scalar value in DML format from HDFS
*/
@@ -598,7 +598,7 @@ public class TestUtils
}
return _AssertOccured;
}
-
+
public static String readDMLString(String filePath) {
try {
StringBuilder sb = new StringBuilder();
@@ -617,8 +617,8 @@ public class TestUtils
}
return null;
}
-
-
+
+
/**
* Reads a scalar value in R format from OS's FS
*/
@@ -627,7 +627,7 @@ public class TestUtils
expectedValues.put(new CellIndex(1,1), readRScalar(filePath));
return expectedValues;
}
-
+
public static Double readRScalar(String filePath) {
try {
double d = Double.NaN;
@@ -643,12 +643,12 @@ public class TestUtils
}
return Double.NaN;
}
-
+
public static String processMultiPartCSVForR(String csvFile) throws
IOException {
File csv = new File(csvFile);
if (csv.isDirectory()) {
File[] parts = csv.listFiles();
-
+
int count=0;
int index = -1;
for(int i=0; i < parts.length; i++ ) {
@@ -659,7 +659,7 @@ public class TestUtils
count++;
index = i;
}
-
+
if ( count == 1) {
csvFile = parts[index].toString();
}
@@ -686,7 +686,7 @@ public class TestUtils
out.append(fileContents);
}
}
-
+
csvFile = tmp.getCanonicalPath();
}
else {
@@ -699,7 +699,7 @@ public class TestUtils
/**
* Compares two double values regarding tolerance t. If one or both of
them
* is null it is converted to 0.0.
- *
+ *
* @param v1
* @param v2
* @param t Tolerance
@@ -722,13 +722,13 @@ public class TestUtils
return Math.abs(v1 - v2) <= t;
}
-
+
public static void compareMatrices(double[] expectedMatrix, double[]
actualMatrix, double epsilon) {
- compareMatrices(new double[][]{expectedMatrix},
+ compareMatrices(new double[][]{expectedMatrix},
new double[][]{actualMatrix}, 1, expectedMatrix.length,
epsilon);
}
-
-
+
+
public static void compareMatrices(double[][] expectedMatrix,
double[][] actualMatrix, int rows, int cols,
double epsilon) {
compareMatrices(expectedMatrix, actualMatrix,
expectedMatrix.length, expectedMatrix[0].length, epsilon, "");
@@ -760,7 +760,7 @@ public class TestUtils
assertEqualColsAndRows(expectedMatrix,actualMatrix);
compareMatrices(expectedMatrix, actualMatrix,
expectedMatrix.length, expectedMatrix[0].length, epsilon, message);
}
-
+
public static void compareFrames(String[][] expectedFrame, String[][]
actualFrame, int rows, int cols ) {
int countErrors = 0;
for (int i = 0; i < rows; i++) {
@@ -774,9 +774,9 @@ public class TestUtils
}
assertTrue("" + countErrors + " values are not in equal",
countErrors == 0);
}
-
+
public static void compareScalars(double d1, double d2, double tol) {
- assertTrue("Given scalars do not match: " + d1 + " != " + d2 ,
compareCellValue(d1, d2, tol, false));
+ assertTrue("Given scalars do not match: " + d1 + " != " + d2 ,
compareCellValue(d1, d2, tol, false));
}
public static void compareMatricesBit(double[][] expectedMatrix,
double[][] actualMatrix, int rows, int cols,
@@ -796,7 +796,7 @@ public class TestUtils
public static void compareMatricesBitAvgDistance(double[][]
expectedMatrix, double[][] actualMatrix,
long maxUnitsOfLeastPrecision, long maxAvgDistance,
String message){
assertEqualColsAndRows(expectedMatrix,actualMatrix);
- compareMatricesBitAvgDistance(expectedMatrix, actualMatrix,
expectedMatrix.length, actualMatrix[0].length,
+ compareMatricesBitAvgDistance(expectedMatrix, actualMatrix,
expectedMatrix.length, actualMatrix[0].length,
maxUnitsOfLeastPrecision, maxAvgDistance, message);
}
@@ -853,24 +853,24 @@ public class TestUtils
}
private static void assertEqualColsAndRows(double[][] expectedMatrix,
double[][] actualMatrix){
- assertTrue("The number of columns in the matrixes should be
equal :"
+ assertTrue("The number of columns in the matrixes should be
equal :"
+ expectedMatrix.length + " "
- + actualMatrix.length,
+ + actualMatrix.length,
expectedMatrix.length == actualMatrix.length);
- assertTrue("The number of rows in the matrixes should be equal"
- + expectedMatrix[0].length + " "
- + actualMatrix[0].length,
+ assertTrue("The number of rows in the matrixes should be equal"
+ + expectedMatrix[0].length + " "
+ + actualMatrix[0].length,
expectedMatrix[0].length == actualMatrix[0].length);
}
- public static void compareMatricesPercentageDistance(double[][]
expectedMatrix, double[][] actualMatrix,
+ public static void compareMatricesPercentageDistance(double[][]
expectedMatrix, double[][] actualMatrix,
double percentDistanceAllowed, double
maxAveragePercentDistance, String message){
assertEqualColsAndRows(expectedMatrix,actualMatrix);
compareMatricesPercentageDistance(expectedMatrix, actualMatrix,
expectedMatrix.length, expectedMatrix[0].length,
percentDistanceAllowed, maxAveragePercentDistance,
message, false);
}
- public static void compareMatricesPercentageDistance(double[][]
expectedMatrix, double[][] actualMatrix,
+ public static void compareMatricesPercentageDistance(double[][]
expectedMatrix, double[][] actualMatrix,
double percentDistanceAllowed, double
maxAveragePercentDistance, String message, boolean ignoreZero){
assertEqualColsAndRows(expectedMatrix,actualMatrix);
compareMatricesPercentageDistance(expectedMatrix, actualMatrix,
expectedMatrix.length, expectedMatrix[0].length,
@@ -907,6 +907,29 @@ public class TestUtils
}
}
+ public static void compareMatricesAvgRowDistance(double[][]
expectedMatrix, double[][] actualMatrix, int rows,
+ int cols, double averageDistanceAllowed){
+ String message = "";
+ int countErrors = 0;
+
+ for (int i = 0; i < rows && countErrors < 20; i++) {
+ double distanceSum = 0;
+ for (int j = 0; j < cols && countErrors < 20;
j++) {
+ distanceSum += expectedMatrix[i][j] -
actualMatrix[i][j];
+ }
+ if(distanceSum / cols > averageDistanceAllowed){
+ message += ("Average distance for row "
+ i + ":" + (distanceSum / cols) + "\n");
+ countErrors++;
+ }
+ }
+ if(countErrors == 20){
+ assertTrue(message + "\n At least 20 values are
not in equal", countErrors == 0);
+ }
+ else{
+ assertTrue(message + "\n" + countErrors + "
values are not in equal of total: " + (rows), countErrors == 0);
+ }
+ }
+
public static void compareMatricesBitAvgDistance(double[][]
expectedMatrix, double[][] actualMatrix, int rows,
int cols, long maxUnitsOfLeastPrecision, long maxAvgDistance) {
compareMatricesBitAvgDistance(expectedMatrix,
actualMatrix, rows, cols, maxUnitsOfLeastPrecision, maxAvgDistance, "");
@@ -914,24 +937,24 @@ public class TestUtils
/**
* Compare two double precision floats for equality within a margin of
error.
- *
+ *
* This can be used to compensate for inequality caused by accumulated
* floating point math errors.
- *
+ *
* The error margin is specified in ULPs (units of least precision).
* A one-ULP difference means there are no representable floats in
between.
* E.g. 0f and 1.4e-45f are one ULP apart. So are -6.1340704f and
-6.13407f.
* Depending on the number of calculations involved, typically a margin
of
* 1-5 ULPs should be enough.
- *
+ *
* @param d1 The expected value.
* @param d2 The actual value.
* @return Whether distance in bits
*/
public static long compareScalarBits(double d1, double d2) {
-
+
// assertTrue("Both values should be positive or negative",(d1
>= 0 && d2 >= 0) || (d2 <= 0 && d1 <= 0));
-
+
long expectedBits = Double.doubleToLongBits(d1) < 0 ?
0x8000000000000000L - Double.doubleToLongBits(d1) : Double.doubleToLongBits(d1);
long actualBits = Double.doubleToLongBits(d2) < 0 ?
0x8000000000000000L - Double.doubleToLongBits(d2) : Double.doubleToLongBits(d2);
long difference = expectedBits > actualBits ? expectedBits -
actualBits : actualBits - expectedBits;
@@ -954,29 +977,29 @@ public class TestUtils
long distance = compareScalarBits(d1,d2);
assertTrue("Given scalars do not match: " + d1 + " != " + d2 +
" with bitDistance: " + distance ,distance <= maxUnitsOfLeastPrecision);
}
-
+
public static void compareScalars(String expected, String actual) {
assertEquals(expected, actual);
}
public static boolean compareMatrices(HashMap<CellIndex, Double> m1,
HashMap<CellIndex, Double> m2,
- double tolerance, String name1, String name2)
+ double tolerance, String name1, String name2)
{
return compareMatrices(m1, m2, tolerance, name1, name2, false);
}
-
+
public static void compareMatrices(HashMap<CellIndex, Double> m1,
MatrixBlock m2, double tolerance) {
double[][] ret1 = convertHashMapToDoubleArray(m1);
double[][] ret2 = DataConverter.convertToDoubleMatrix(m2);
compareMatrices(ret1, ret2, m2.getNumRows(),
m2.getNumColumns(), tolerance);
}
-
+
public static void compareMatrices(MatrixBlock m1, MatrixBlock m2,
double tolerance) {
double[][] ret1 = DataConverter.convertToDoubleMatrix(m1);
double[][] ret2 = DataConverter.convertToDoubleMatrix(m2);
compareMatrices(ret1, ret2, m2.getNumRows(),
m2.getNumColumns(), tolerance);
}
-
+
/**
* Compares two matrices given as HashMaps. The matrix containing more
nnz
* is iterated and each cell value compared against the corresponding
cell
@@ -984,7 +1007,7 @@ public class TestUtils
* This method does not assert. Instead statistics are added to
* AssertionBuffer, at the end of the test you should call
* {@link TestUtils#displayAssertionBuffer()}.
- *
+ *
* @param m1
* @param m2
* @param tolerance
@@ -997,7 +1020,7 @@ public class TestUtils
String namefirst = name2;
String namesecond = name1;
boolean flag = true;
-
+
// to ensure that always the matrix with more nnz is iterated
if (m1.size() > m2.size()) {
first = m1;
@@ -1024,7 +1047,7 @@ public class TestUtils
countErrorWithinTolerance++;
if(!flag)
System.out.println(e.getKey()+": "+v1+" <--> "+v2);
- else
+ else
System.out.println(e.getKey()+": "+v2+" <--> "+v1);
}
} else {
@@ -1049,35 +1072,35 @@ public class TestUtils
_AssertOccured = true;
return false;
}
-
-
+
+
/**
- *
+ *
* @param vt
* @param in1
* @param in2
* @param tolerance
- *
+ *
* @return
*/
public static int compareTo(ValueType vt, Object in1, Object in2,
double tolerance) {
if(in1 == null && in2 == null) return 0;
else if(in1 == null) return -1;
else if(in2 == null) return 1;
-
+
switch( vt ) {
case STRING: return
((String)in1).compareTo((String)in2);
case BOOLEAN: return
((Boolean)in1).compareTo((Boolean)in2);
case INT64: return ((Long)in1).compareTo((Long)in2);
- case FP64:
+ case FP64:
return (Math.abs((Double)in1-(Double)in2) <
tolerance)?0:
((Double)in1).compareTo((Double)in2);
default: throw new RuntimeException("Unsupported value
type: "+vt);
}
}
-
+
/**
- *
+ *
* @param vt
* @param in1
* @param inR
@@ -1087,32 +1110,32 @@ public class TestUtils
if(in1 == null && (inR == null ||
(inR.toString().compareTo("NA")==0))) return 0;
else if(in1 == null && vt == ValueType.STRING) return -1;
else if(inR == null) return 1;
-
+
switch( vt ) {
case STRING: return
((String)in1).compareTo((String)inR);
- case BOOLEAN:
+ case BOOLEAN:
if(in1 == null)
return
Boolean.FALSE.compareTo(((Boolean)inR).booleanValue());
else
return
((Boolean)in1).compareTo((Boolean)inR);
- case INT64:
+ case INT64:
if(in1 == null)
return new
Long(0).compareTo(((Long)inR));
else
return ((Long)in1).compareTo((Long)inR);
- case FP64:
+ case FP64:
if(in1 == null)
return (new
Double(0)).compareTo((Double)inR);
else
- return
(Math.abs((Double)in1-(Double)inR) < tolerance)?0:
+ return
(Math.abs((Double)in1-(Double)inR) < tolerance)?0:
((Double)in1).compareTo((Double)inR);
default: throw new RuntimeException("Unsupported value
type: "+vt);
}
}
-
+
/**
* Converts a 2D array into a sparse hashmap matrix.
- *
+ *
* @param matrix
* @return
*/
@@ -1127,7 +1150,7 @@ public class TestUtils
return hmMatrix;
}
-
+
/**
* Method to convert a hashmap of matrix entries into a double array
* @param matrix
@@ -1147,37 +1170,37 @@ public class TestUtils
max_cols = ci.column;
}
}
-
+
double [][] ret_arr = new double[max_rows][max_cols];
-
+
for(CellIndex ci:matrix.keySet())
{
int i = ci.row-1;
int j = ci.column-1;
ret_arr[i][j] = matrix.get(ci);
}
-
+
return ret_arr;
-
+
}
-
+
public static double[][] convertHashMapToDoubleArray(HashMap
<CellIndex, Double> matrix, int rows, int cols)
{
double [][] ret_arr = new double[rows][cols];
-
+
for(CellIndex ci:matrix.keySet()) {
int i = ci.row-1;
int j = ci.column-1;
ret_arr[i][j] = matrix.get(ci);
}
-
+
return ret_arr;
-
+
}
/**
* Converts a 2D double array into a 1D double array.
- *
+ *
* @param array
* @return
*/
@@ -1195,7 +1218,7 @@ public class TestUtils
/**
* Converts a 1D double array into a 2D double array.
- *
+ *
* @param array
* @return
*/
@@ -1228,7 +1251,7 @@ public class TestUtils
* Compares a dml matrix file in HDFS with a file in normal file system
* generated by R
* </p>
- *
+ *
* @param rFile
* file with values calculated by R
* @param hdfsDir
@@ -1248,7 +1271,7 @@ public class TestUtils
compareIn.readLine();
readValuesFromFileStreamAndPut(compareIn,
expectedValues);
}
-
+
FileStatus[] outFiles = fs.listStatus(outDirectory);
for (FileStatus file : outFiles) {
@@ -1282,7 +1305,7 @@ public class TestUtils
* <p>
* Checks a matrix against a number of specifications.
* </p>
- *
+ *
* @param data
* matrix data
* @param mc
@@ -1312,7 +1335,7 @@ public class TestUtils
* Checks a matrix read from a file in text format against a number of
* specifications.
* </p>
- *
+ *
* @param outDir
* directory containing the matrix
* @param rows
@@ -1329,7 +1352,7 @@ public class TestUtils
Path outDirectory = new Path(outDir);
FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
assertTrue(outDir + " does not exist",
fs.exists(outDirectory));
-
+
if( fs.getFileStatus(outDirectory).isDirectory() )
{
FileStatus[] outFiles =
fs.listStatus(outDirectory);
@@ -1374,7 +1397,7 @@ public class TestUtils
* <p>
* Checks for matrix in directory existence.
* </p>
- *
+ *
* @param outDir
* directory
*/
@@ -1401,7 +1424,7 @@ public class TestUtils
* <p>
* Removes all the directories specified in the array in HDFS
* </p>
- *
+ *
* @param directories
* directories array
*/
@@ -1422,7 +1445,7 @@ public class TestUtils
* <p>
* Removes all the directories specified in the array in OS filesystem
* </p>
- *
+ *
* @param directories
* directories array
*/
@@ -1451,7 +1474,7 @@ public class TestUtils
* <p>
* Removes all the files specified in the array in HDFS
* </p>
- *
+ *
* @param files
* files array
*/
@@ -1472,7 +1495,7 @@ public class TestUtils
* <p>
* Removes all the files specified in the array in OS filesystem
* </p>
- *
+ *
* @param files
* files array
*/
@@ -1490,7 +1513,7 @@ public class TestUtils
* <p>
* Clears a complete directory.
* </p>
- *
+ *
* @param directory
* directory
*/
@@ -1514,7 +1537,7 @@ public class TestUtils
* <p>
* Set seed to -1 to use the current time as seed.
* </p>
- *
+ *
* @param rows
* number of rows
* @param cols
@@ -1547,7 +1570,7 @@ public class TestUtils
* Generates a test matrix with the specified parameters as a two
* dimensional array.
* Set seed to -1 to use the current time as seed.
- *
+ *
* @param rows number of rows
* @param cols number of columns
* @param min minimum value
@@ -1573,9 +1596,9 @@ public class TestUtils
}
/**
- *
+ *
* Generates a test matrix, but only containing real numbers, in the
range specified.
- *
+ *
* @param rows number of rows
* @param cols number of columns
* @param min minimum value whole number
@@ -1616,7 +1639,7 @@ public class TestUtils
* <p>
* Set seed to -1 to use the current time as seed.
* </p>
- *
+ *
* @param rows
* number of rows
* @param cols
@@ -1653,7 +1676,7 @@ public class TestUtils
* <p>
* Set seed to -1 to use the current time as seed.
* </p>
- *
+ *
* @param file
* output file
* @param rows
@@ -1677,7 +1700,7 @@ public class TestUtils
DataOutputStream out = fs.create(inFile);
try( PrintWriter pw = new PrintWriter(out) ) {
Random random = (seed == -1) ? TestUtils.random
: new Random(seed);
-
+
for (int i = 1; i <= rows; i++) {
for (int j = 1; j <= cols; j++) {
if (random.nextDouble() >
sparsity)
@@ -1880,7 +1903,7 @@ public class TestUtils
/**
* Counts the number of NNZ values in a matrix
- *
+ *
* @param matrix
* @return
*/
@@ -1895,9 +1918,9 @@ public class TestUtils
return n;
}
- public static void writeCSVTestMatrix(String file, double[][] matrix)
+ public static void writeCSVTestMatrix(String file, double[][] matrix)
{
- try
+ try
{
//create outputstream to HDFS / FS and writer
Path path = new Path(file);
@@ -1912,7 +1935,7 @@ public class TestUtils
sb.append(matrix[i][0]);
for (int j = 1; j < matrix[i].length;
j++) {
sb.append(",");
- if ( matrix[i][j] == 0 )
+ if ( matrix[i][j] == 0 )
continue;
sb.append(matrix[i][j]);
}
@@ -1920,8 +1943,8 @@ public class TestUtils
pw.append(sb.toString());
}
}
- }
- catch (IOException e)
+ }
+ catch (IOException e)
{
fail("unable to write (csv) test matrix (" + file + "):
" + e.getMessage());
}
@@ -1931,18 +1954,18 @@ public class TestUtils
* <p>
* Writes a matrix to a file using the text format.
* </p>
- *
+ *
* @param file
* file name
* @param matrix
* matrix
* @param isR
* when true, writes a R matrix to disk
- *
+ *
*/
- public static void writeTestMatrix(String file, double[][] matrix,
boolean isR)
+ public static void writeTestMatrix(String file, double[][] matrix,
boolean isR)
{
- try
+ try
{
//create outputstream to HDFS / FS and writer
DataOutputStream out = null;
@@ -1950,26 +1973,26 @@ public class TestUtils
Path path = new Path(file);
FileSystem fs =
IOUtilFunctions.getFileSystem(path, conf);
out = fs.create(path, true);
- }
+ }
else {
out = new DataOutputStream(new
FileOutputStream(file));
}
-
+
try( BufferedWriter pw = new BufferedWriter(new
OutputStreamWriter(out))) {
-
+
//write header
if( isR ) {
/** add R header */
pw.append("%%MatrixMarket matrix
coordinate real general\n");
pw.append("" + matrix.length + " " +
matrix[0].length + " " + matrix.length*matrix[0].length+"\n");
}
-
+
//writer actual matrix
StringBuilder sb = new StringBuilder();
boolean emptyOutput = true;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length;
j++) {
- if ( matrix[i][j] == 0 )
+ if ( matrix[i][j] == 0 )
continue;
sb.append(i + 1);
sb.append(' ');
@@ -1982,13 +2005,13 @@ public class TestUtils
emptyOutput = false;
}
}
-
+
//writer dummy entry if empty
if( emptyOutput )
pw.append("1 1 " + matrix[0][0]);
}
- }
- catch (IOException e)
+ }
+ catch (IOException e)
{
fail("unable to write test matrix (" + file + "): " +
e.getMessage());
}
@@ -1998,7 +2021,7 @@ public class TestUtils
* <p>
* Writes a matrix to a file using the text format.
* </p>
- *
+ *
* @param file
* file name
* @param matrix
@@ -2008,18 +2031,18 @@ public class TestUtils
writeTestMatrix(file, matrix, false);
}
-
+
/**
* <p>
* Writes a frame to a file using the text format.
* </p>
- *
+ *
* @param file
* file name
* @param data
* frame data
* @param isR
- * @throws IOException
+ * @throws IOException
*/
public static void writeTestFrame(String file, double[][] data,
ValueType[] schema, FileFormat fmt, boolean isR) throws IOException {
FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt);
@@ -2027,17 +2050,17 @@ public class TestUtils
initFrameData(frame, data, schema, data.length);
writer.writeFrameToHDFS(frame, file, data.length,
schema.length);
}
-
+
/**
* <p>
* Writes a frame to a file using the text format.
* </p>
- *
+ *
* @param file
* file name
* @param data
* frame data
- * @throws IOException
+ * @throws IOException
*/
public static void writeTestFrame(String file, double[][] data,
ValueType[] schema, FileFormat fmt) throws IOException {
writeTestFrame(file, data, schema, fmt, false);
@@ -2047,7 +2070,7 @@ public class TestUtils
Object[] row1 = new Object[lschema.length];
for( int i=0; i<rows; i++ ) {
for( int j=0; j<lschema.length; j++ ) {
- data[i][j] =
UtilFunctions.objectToDouble(lschema[j],
+ data[i][j] =
UtilFunctions.objectToDouble(lschema[j],
row1[j] =
UtilFunctions.doubleToObject(lschema[j], data[i][j]));
if(row1[j] != null && lschema[j] ==
ValueType.STRING)
row1[j] = "Str" + row1[j];
@@ -2056,7 +2079,7 @@ public class TestUtils
}
}
-
+
/* Write a scalar value to a file */
public static void writeTestScalar(String file, double value) {
try {
@@ -2079,12 +2102,12 @@ public class TestUtils
fail("unable to write test scalar (" + file + "): " +
e.getMessage());
}
}
-
+
/**
* <p>
* Writes a matrix to a file using the binary cells format.
* </p>
- *
+ *
* @param file
* file name
* @param matrix
@@ -2125,7 +2148,7 @@ public class TestUtils
* <p>
* Writes a matrix to a file using the binary blocks format.
* </p>
- *
+ *
* @param file
* file name
* @param matrix
@@ -2140,7 +2163,7 @@ public class TestUtils
public static void writeBinaryTestMatrixBlocks(String file, double[][]
matrix, int rowsInBlock, int colsInBlock,
boolean sparseFormat) {
SequenceFile.Writer writer = null;
-
+
try {
Path path = new Path(file);
Writer.Option filePath = Writer.file(path);
@@ -2164,7 +2187,7 @@ public class TestUtils
writer.append(index, value);
}
}
- }
+ }
catch (IOException e) {
e.printStackTrace();
fail("unable to write test matrix: " + e.getMessage());
@@ -2178,7 +2201,7 @@ public class TestUtils
* <p>
* Prints out a DML script.
* </p>
- *
+ *
* @param dmlScriptFile
* filename of DML script
*/
@@ -2202,7 +2225,7 @@ public class TestUtils
* <p>
* Prints out a PYDML script.
* </p>
- *
+ *
* @param pydmlScriptFile
* filename of PYDML script
*/
@@ -2214,19 +2237,19 @@ public class TestUtils
while ((content = in.readLine()) != null) {
System.out.println(content);
}
- }
+ }
catch (IOException e) {
e.printStackTrace();
fail("unable to print pydml script: " + e.getMessage());
}
System.out.println("**************************************************\n\n");
}
-
+
/**
* <p>
* Prints out an R script.
* </p>
- *
+ *
* @param dmlScriptFile
* filename of RL script
*/
@@ -2250,7 +2273,7 @@ public class TestUtils
* <p>
* Renames a temporary DML script file back to it's original name.
* </p>
- *
+ *
* @param dmlScriptFile
* temporary script file
*/
@@ -2288,7 +2311,7 @@ public class TestUtils
* Checks if any temporary files or directories exist in the current
working
* directory.
* </p>
- *
+ *
* @return true if temporary files or directories are available
*/
@SuppressWarnings("resource")
@@ -2316,7 +2339,7 @@ public class TestUtils
* Returns the path to a file in a directory if it is the only file in
the
* directory.
* </p>
- *
+ *
* @param directory
* directory containing the file
* @return path of the file
@@ -2342,7 +2365,7 @@ public class TestUtils
* <p>
* Creates an empty file.
* </p>
- *
+ *
* @param filename
* filename
*/
@@ -2356,7 +2379,7 @@ public class TestUtils
* <p>
* Performs transpose onto a matrix and returns the result.
* </p>
- *
+ *
* @param a
* matrix
* @return transposed matrix
@@ -2379,7 +2402,7 @@ public class TestUtils
* <p>
* Performs matrix multiplication onto two matrices and returns the
result.
* </p>
- *
+ *
* @param a
* left matrix
* @param b
@@ -2409,7 +2432,7 @@ public class TestUtils
* <p>
* Returns a random integer value.
* </p>
- *
+ *
* @return random integer value
*/
public static int getRandomInt() {
@@ -2422,7 +2445,7 @@ public class TestUtils
* <p>
* Returns a positive random integer value.
* </p>
- *
+ *
* @return positive random integer value
*/
public static int getPositiveRandomInt() {
@@ -2436,7 +2459,7 @@ public class TestUtils
* <p>
* Returns a negative random integer value.
* </p>
- *
+ *
* @return negative random integer value
*/
public static int getNegativeRandomInt() {
@@ -2450,7 +2473,7 @@ public class TestUtils
* <p>
* Returns a random double value.
* </p>
- *
+ *
* @return random double value
*/
public static double getRandomDouble() {
@@ -2463,7 +2486,7 @@ public class TestUtils
* <p>
* Returns a positive random double value.
* </p>
- *
+ *
* @return positive random double value
*/
public static double getPositiveRandomDouble() {
@@ -2477,7 +2500,7 @@ public class TestUtils
* <p>
* Returns a negative random double value.
* </p>
- *
+ *
* @return negative random double value
*/
public static double getNegativeRandomDouble() {
@@ -2492,7 +2515,7 @@ public class TestUtils
* Returns the string representation of a double value which can be
used in
* a DML script.
* </p>
- *
+ *
* @param value
* double value
* @return string representation
@@ -2504,7 +2527,7 @@ public class TestUtils
nf.setMaximumFractionDigits(20);
return nf.format(value);
}
-
+
public static void replaceRandom( double[][] A, int rows, int cols,
double replacement, int len ) {
Random rand = new Random();
for( int i=0; i<len; i++ )
@@ -2523,7 +2546,7 @@ public class TestUtils
* <p>
* Generates a matrix containing easy to debug values in its cells.
* </p>
- *
+ *
* @param rows
* @param cols
* @param bContainsZeros
@@ -2543,45 +2566,45 @@ public class TestUtils
}
return matrix;
}
-
+
public static double[][] round(double[][] data) {
for(int i=0; i<data.length; i++)
for(int j=0; j<data[i].length; j++)
data[i][j]=Math.round(data[i][j]);
return data;
}
-
+
public static double[][] round(double[][] data, int col) {
for(int i=0; i<data.length; i++)
data[i][col]=Math.round(data[i][col]);
return data;
}
-
+
public static MatrixBlock round(MatrixBlock data) {
return DataConverter.convertToMatrixBlock(
round(DataConverter.convertToDoubleMatrix(data)));
}
-
+
public static double[][] floor(double[][] data) {
for(int i=0; i<data.length; i++)
for(int j=0; j<data[i].length; j++)
data[i][j]=Math.floor(data[i][j]);
return data;
}
-
+
public static double[][] ceil(double[][] data) {
for(int i=0; i<data.length; i++)
for(int j=0; j<data[i].length; j++)
data[i][j]=Math.ceil(data[i][j]);
return data;
}
-
+
public static double[][] floor(double[][] data, int col) {
for(int i=0; i<data.length; i++)
data[i][col]=Math.floor(data[i][col]);
return data;
}
-
+
public static double sum(double[][] data, int rows, int cols) {
double sum = 0;
for (int i = 0; i< rows; i++){
@@ -2591,14 +2614,14 @@ public class TestUtils
}
return sum;
}
-
+
public static long computeNNZ(double[][] data) {
long nnz = 0;
for(int i=0; i<data.length; i++)
nnz += UtilFunctions.computeNnz(data[i], 0,
data[i].length);
return nnz;
}
-
+
public static double[][] seq(int from, int to, int incr) {
int len = (int)UtilFunctions.getSeqLength(from, to, incr);
double[][] ret = new double[len][1];
@@ -2606,7 +2629,7 @@ public class TestUtils
ret[i][0] = val;
return ret;
}
-
+
public static void shutdownThreads(Thread... ts) {
for( Thread t : ts )
shutdownThread(t);
@@ -2616,7 +2639,7 @@ public class TestUtils
for( Process t : ts )
shutdownThread(t);
}
-
+
public static void shutdownThread(Thread t) {
// kill the worker
if( t != null ) {
@@ -2642,11 +2665,11 @@ public class TestUtils
}
}
}
-
+
public static String federatedAddress(int port, String input) {
return federatedAddress("localhost", port, input);
}
-
+
public static String federatedAddress(String host, int port, String
input) {
return host + ':' + port + '/' + input;
}
@@ -2988,7 +3011,7 @@ public class TestUtils
return output;
}
}
-
+
public static double[][] generateUnbalancedGLMInputDataX(int rows, int
cols, double logFeatureVarianceDisbalance) {
double[][] X = generateTestMatrix(rows, cols, -1.0, 1.0, 1.0,
34567);
double shift_X = 1.0;
@@ -3000,14 +3023,14 @@ public class TestUtils
}
return X;
}
-
+
public static double[] generateUnbalancedGLMInputDataB(double[][] X,
int cols, double intercept, double avgLinearForm, double stdevLinearForm,
Random r) {
double[] beta_unscaled = new double[cols];
for (int j = 0; j < cols; j++)
beta_unscaled[j] = r.nextGaussian();
return scaleWeights(beta_unscaled, X, intercept, avgLinearForm,
stdevLinearForm);
}
-
+
public static double[][] generateUnbalancedGLMInputDataY(double[][] X,
double[] beta, int rows, int cols, GLMDist glmdist, double intercept, double
dispersion, Random r) {
double[][] y = null;
if (glmdist.is_binom_n_needed())
@@ -3030,7 +3053,7 @@ public class TestUtils
y[i][0] = glmdist.nextGLM(r, eta);
}
}
-
+
return y;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java
new file mode 100644
index 0000000..e62ea9f
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNBFTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import java.util.HashMap;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+@RunWith(value = Parameterized.class)
+public class BuiltinKNNBFTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "knnbf";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinKNNBFTest.class.getSimpleName() + "/";
+
+ private final static String OUTPUT_NAME = "B";
+
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public int query_rows;
+ @Parameterized.Parameter(3)
+ public int query_cols;
+ @Parameterized.Parameter(4)
+ public boolean continuous;
+ @Parameterized.Parameter(5)
+ public int k_value;
+ @Parameterized.Parameter(6)
+ public double sparsity;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data()
+ {
+ return Arrays.asList(new Object[][] {
+ // {rows, cols, query_rows, query_cols, continuous,
k_value, sparsity}
+ {150, 80, 15, 80, true, 21, 0.9}
+ });
+ }
+
+ @Test
+ public void testKNN() {
+ runKNNTest(ExecMode.SINGLE_NODE);
+ }
+
+ private void runKNNTest(ExecMode exec_mode)
+ {
+ ExecMode platform_old = setExecMode(exec_mode);
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, sparsity, 255);
+ double[][] T = getRandomMatrix(query_rows, query_cols, 0, 1, 1,
65);
+
+ double[][] CL = new double[rows][1];
+ for(int counter = 0; counter < rows; counter++)
+ CL[counter][0] = counter + 1;
+
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("T", T, true);
+ writeInputMatrixWithMTD("CL", CL, true);
+
+ // execute reference test
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X=" + input("X"), "in_T=" + input("T"), "in_CL=" +
input("CL"), "in_continuous=" + (continuous ? "1" : "0"), "in_k=" +
Integer.toString(k_value),
+ "out_B=" + expected(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // execute actual test
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X=" + input("X"), "in_T=" + input("T"),
"in_continuous=" + (continuous ? "1" : "0"), "in_k=" +
Integer.toString(k_value),
+ "out_B=" + output(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir("B");
+ HashMap<CellIndex, Double> results =
readDMLMatrixFromOutputDir("B");
+
+ TestUtils.compareMatrices(results, refResults, 0, "Res", "Ref");
+
+ // restore execution mode
+ setExecMode(platform_old);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java
new file mode 100644
index 0000000..e2f10a3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKNNTest.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import java.util.HashMap;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+@RunWith(value = Parameterized.class)
+public class BuiltinKNNTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "knn";
+ private final static String TEST_DIR = "functions/builtin/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
BuiltinKNNTest.class.getSimpleName() + "/";
+
+ private final static String OUTPUT_NAME_NNR = "NNR";
+ private final static String OUTPUT_NAME_PR = "PR";
+
+ private final static double TEST_TOLERANCE = 0.15;
+
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public int query_rows;
+ @Parameterized.Parameter(3)
+ public int query_cols;
+ @Parameterized.Parameter(4)
+ public boolean continuous;
+ @Parameterized.Parameter(5)
+ public int k_value;
+ @Parameterized.Parameter(6)
+ public double sparsity;
+
+ @Override
+ public void setUp()
+ {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME_NNR,
OUTPUT_NAME_PR}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data()
+ {
+ return Arrays.asList(new Object[][] {
+ // {rows, cols, query_rows, query_cols, continuous,
k_value, sparsity}
+ {100, 20, 3, 20, true, 3, 1}
+ });
+ }
+
+ @Test
+ @Ignore //TODO add R libraries FNN and class to docker image (see installDependencies.R)
+ public void testKNN() {
+ runKNNTest(ExecMode.SINGLE_NODE);
+ }
+
+ private void runKNNTest(ExecMode exec_mode)
+ {
+ ExecMode platform_old = setExecMode(exec_mode);
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // create Train and Test data
+ double[][] X = getRandomMatrix(rows, cols, 0, 1, sparsity, 75);
+ double[][] T = getRandomMatrix(query_rows, query_cols, 0, 1, 1,
65);
+
+ double[][] CL = new double[rows][1];
+ for(int counter = 0; counter < rows; counter++)
+ CL[counter][0] = counter + 1;
+
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("T", T, true);
+ writeInputMatrixWithMTD("CL", CL, true);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X=" + input("X"), "in_T=" + input("T"), "in_CL=" +
input("CL"), "in_continuous=" + (continuous ? "1" : "0"), "in_k=" +
Integer.toString(k_value),
+ "out_NNR=" + output(OUTPUT_NAME_NNR), "out_PR=" +
output(OUTPUT_NAME_PR)};
+
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = getRCmd(inputDir(), (continuous ? "1" : "0"),
Integer.toString(k_value),
+ expectedDir());
+
+ // execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ // compare test results of RScript with dml script via files
+ HashMap<CellIndex, Double> refNNR =
readRMatrixFromExpectedDir("NNR");
+ HashMap<CellIndex, Double> resNNR =
readDMLMatrixFromOutputDir("NNR");
+
+ TestUtils.compareMatrices(resNNR, refNNR, 0, "ResNNR",
"RefNNR");
+
+ double[][] refPR =
TestUtils.convertHashMapToDoubleArray(readRMatrixFromExpectedDir("PR"));
+ double[][] resPR =
TestUtils.convertHashMapToDoubleArray(readDMLMatrixFromOutputDir("PR"));
+
+ TestUtils.compareMatricesAvgRowDistance(refPR, resPR,
query_rows, query_cols, TEST_TOLERANCE);
+
+ // restore execution mode
+ setExecMode(platform_old);
+ }
+}
diff --git a/src/test/scripts/functions/builtin/knn.R
b/src/test/scripts/functions/builtin/knn.R
new file mode 100644
index 0000000..45ba7c3
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knn.R
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# TODO: clean up command-line argument handling and document the expected order
+args <- commandArgs(TRUE)
+library("Matrix")
+
+# read test data
+data_train <- as.matrix(readMM(paste(args[1], "/X.mtx", sep="")))
+data_test <- as.matrix(readMM(paste(args[1], "/T.mtx", sep="")))
+CL <- as.matrix(readMM(paste(args[1], "/CL.mtx", sep="")))
+
+is_continuous <- as.integer(args[2])
+K <- as.integer(args[3])
+
+library(FNN);
+set.seed(10);
+tmp_data = rbind(data_train, data_test);
+knn_neighbors <- get.knn(tmp_data, k=K);
+knn_neighbors <- (tail(knn_neighbors$nn.index, NROW(data_test)));
+writeMM(as(knn_neighbors, "CsparseMatrix"), paste(args[4], "NNR", sep=""));
+
+
+# ------ training -------
+library(class)
+
+set.seed(10);
+test_pred <- knn(train=data_train, test=data_test, cl=CL, k=K);
+print("test_pred:")
+print(test_pred)
+PR_val <- matrix( , nrow=0, ncol=NCOL(data_test));
+for(i in 1:NROW(data_test)) {
+ PR_val <- rbind(PR_val, data_train[test_pred[i] , ])
+}
+writeMM(as(PR_val, "CsparseMatrix"), paste(args[4], "PR", sep=""));
diff --git a/src/test/scripts/functions/builtin/knn.dml
b/src/test/scripts/functions/builtin/knn.dml
new file mode 100644
index 0000000..8ea5a7e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knn.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+CL = read($in_CL)
+k = $in_k
+
+[NNR, PR, FI] = knn(Train=X, Test=T, CL=CL, k_value=k, predict_con_tg=1);
+
+PR_val = matrix(0, 0, ncol(T));
+for(i in 1:nrow(T)) {
+ PR_val = rbind(PR_val, X[as.scalar(PR[i]), ]);
+}
+
+write(NNR, $out_NNR);
+write(PR_val, $out_PR);
diff --git a/src/test/scripts/functions/builtin/knnbf.dml
b/src/test/scripts/functions/builtin/knnbf.dml
new file mode 100644
index 0000000..e5ae2de
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knnbf.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+k = $in_k
+
+NNR = knnbf(X=X, T=T, k_value = k)
+
+write(NNR, $out_B)
diff --git a/src/test/scripts/functions/builtin/knnbfReference.dml
b/src/test/scripts/functions/builtin/knnbfReference.dml
new file mode 100644
index 0000000..994f466
--- /dev/null
+++ b/src/test/scripts/functions/builtin/knnbfReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($in_X)
+T = read($in_T)
+CL = read($in_CL)
+k = $in_k
+
+[NNR, PR, FI] = knn(Train=X, Test=T, CL=CL, k_value=k);
+
+write(NNR, $out_B)
diff --git a/src/test/scripts/installDependencies.R
b/src/test/scripts/installDependencies.R
index b8b2e66..7ae4159 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -58,6 +58,8 @@ custom_install("mice");
custom_install("mclust");
custom_install("dbscan");
custom_install("imputeTS");
+custom_install("FNN");
+custom_install("class");
print("Installation Done")