This is an automated email from the ASF dual-hosted git repository.

qkou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9cdac22  R-package RNN refactor (#7476)
9cdac22 is described below

commit 9cdac22dc5ab2090ea9090ddedd25dec712b5b55
Author: jeremiedb <jeremi...@users.noreply.github.com>
AuthorDate: Wed Sep 20 13:07:42 2017 -0400

    R-package RNN refactor (#7476)
    
    * R-package RNN refactor
---
 R-package/R/model.rnn.R                            | 339 ++++++++++++++++++++
 R-package/R/mx.io.bucket.iter.R                    | 110 +++++++
 R-package/R/rnn.R                                  | 342 ---------------------
 R-package/R/rnn.graph.R                            | 283 +++++++++++++++++
 R-package/R/rnn.infer.R                            | 177 +++++++++++
 R-package/R/rnn_model.R                            | 258 ----------------
 R-package/R/viz.graph.R                            |   4 +-
 R-package/tests/testthat/test_lstm.R               |  57 ----
 .../rnn/bucket_R/data_preprocessing_seq_to_one.R   | 176 +++++++++++
 9 files changed, 1086 insertions(+), 660 deletions(-)

diff --git a/R-package/R/model.rnn.R b/R-package/R/model.rnn.R
new file mode 100644
index 0000000..8f3ab8c
--- /dev/null
+++ b/R-package/R/model.rnn.R
@@ -0,0 +1,339 @@
+# Internal function to do multiple device training on RNN
+#
+# Trains `symbol` for epochs begin.round..end.round and returns the model
+# extracted (mx.model.extract.model) from the executors after the last epoch.
+#
+# Args:
+#   symbol: a Symbol, or a list of Symbols indexed by bucket name
+#     (names(<iter>$bucketID)); with a list, executors are rebound per batch.
+#   ctx: list of contexts; one executor is bound per context.
+#   train.data, eval.data: bucket iterators (reset, iter.next, value,
+#     bucketID); eval.data may be NULL.
+#   dlist: named list of input NDArrays used for the initial bind; its names
+#     select the inputs taken from every later batch.
+#   arg.params, aux.params: parameter / auxiliary NDArrays.
+#   grad.req: gradient request for each symbol argument.
+#   arg.update.idx: reorders c(inputs, arg.params) into the symbol's
+#     argument order.
+#   begin.round, end.round: first and last epoch numbers.
+#   optimizer: optimizer given to the kvstore or to mx.opt.get.updater.
+#   metric: object with init/update/get, or NULL.
+#   epoch.end.callback, batch.end.callback: optional callbacks; training
+#     stops when epoch.end.callback returns FALSE.
+#   kvstore: KVStore or NULL.
+#   verbose: print progress messages.
+mx.model.train.buckets <- function(symbol, ctx, train.data, eval.data,
+                                   dlist, arg.params, aux.params,
+                                   grad.req, arg.update.idx,
+                                   begin.round, end.round, optimizer, metric,
+                                   epoch.end.callback, batch.end.callback,
+                                   kvstore, verbose = TRUE) {
+
+  ndevice <- length(ctx)
+  if (verbose)
+    message(paste0("Start training with ", ndevice, " devices"))
+
+  input.names <- names(dlist)
+  arg.params.names <- names(arg.params)
+
+  if (is.list(symbol)) sym_ini <- symbol[[names(train.data$bucketID)]] else sym_ini <- symbol
+
+  # Split each input NDArray into one slice per device. With ndevice > 1,
+  # mx.nd.split yields one output per device, so device i must take its own
+  # element instead of receiving the whole list of splits.
+  # NOTE(review): axis = 0 is kept from the original - confirm it is the
+  # batch axis for every input layout.
+  slice_inputs <- function(inputs) {
+    lapply(seq_len(ndevice), function(i) {
+      lapply(inputs, function(x) {
+        parts <- mx.nd.split(data = x, num_outputs = ndevice, axis = 0,
+                             squeeze_axis = FALSE)
+        if (ndevice > 1) parts[[i]] else parts
+      })
+    })
+  }
+
+  # Feed per-device slices of a batch to the executors and return the
+  # executors to use. With one symbol per bucket the executors are rebound to
+  # the bucket's symbol, carrying over the current argument and auxiliary
+  # arrays; otherwise the inputs are updated in place.
+  # Known issue noted by the author: bug on inference if using BatchNorm.
+  load_batch <- function(execs, slices, bucket.name) {
+    if (is.list(symbol)) {
+      lapply(seq_len(ndevice), function(i) {
+        mx.symbol.bind(symbol = symbol[[bucket.name]],
+                       arg.arrays = c(slices[[i]],
+                                      execs[[i]]$arg.arrays[arg.params.names])[arg.update.idx],
+                       aux.arrays = execs[[i]]$aux.arrays, ctx = ctx[[i]],
+                       grad.req = grad.req)
+      })
+    } else {
+      for (i in seq_len(ndevice)) {
+        mx.exec.update.arg.arrays(execs[[i]], slices[[i]], match.name = TRUE)
+      }
+      execs
+    }
+  }
+
+  # Copy the first output of every executor to the CPU.
+  copy_outputs <- function(execs) {
+    lapply(execs, function(texec) mx.nd.copyto(texec$ref.outputs[[1]], mx.cpu()))
+  }
+
+  slices <- slice_inputs(dlist)
+  train.execs <- lapply(seq_len(ndevice), function(i) {
+    mx.symbol.bind(symbol = sym_ini,
+                   arg.arrays = c(slices[[i]], arg.params)[arg.update.idx],
+                   aux.arrays = aux.params, ctx = ctx[[i]], grad.req = grad.req)
+  })
+
+  # KVStore related stuffs: indices of the arguments that carry gradients
+  params.index <- as.integer(
+    mx.util.filter.null(
+      lapply(seq_along(train.execs[[1]]$ref.grad.arrays), function(k) {
+        if (!is.null(train.execs[[1]]$ref.grad.arrays[[k]])) k else NULL
+      })))
+
+  update.on.kvstore <- FALSE
+  if (!is.null(kvstore) && kvstore$update.on.kvstore) {
+    update.on.kvstore <- TRUE
+    kvstore$set.optimizer(optimizer)
+  } else {
+    updaters <- lapply(seq_len(ndevice), function(i) {
+      mx.opt.get.updater(optimizer, train.execs[[i]]$ref.arg.arrays)
+    })
+  }
+
+  if (!is.null(kvstore)) {
+    kvstore$init(params.index, train.execs[[1]]$ref.arg.arrays[params.index])
+  }
+
+  # train over specified number of epochs
+  for (iteration in begin.round:end.round) {
+    nbatch <- 0
+    if (!is.null(metric)) {
+      train.metric <- metric$init()
+    }
+    train.data$reset()
+    while (train.data$iter.next()) {
+      # Get iterator data, slice it per device and assign it to the executors
+      dlist <- train.data$value()[input.names]
+      slices <- slice_inputs(dlist)
+      train.execs <- load_batch(train.execs, slices, names(train.data$bucketID))
+
+      for (texec in train.execs) {
+        mx.exec.forward(texec, is.train = TRUE)
+      }
+
+      out.preds <- copy_outputs(train.execs)
+
+      for (texec in train.execs) {
+        mx.exec.backward(texec)
+      }
+
+      if (!is.null(kvstore)) {
+        # push the gradient
+        kvstore$push(params.index, lapply(train.execs, function(texec) {
+          texec$ref.grad.arrays[params.index]
+        }), -params.index)
+      }
+      if (update.on.kvstore) {
+        # pull back weight
+        kvstore$pull(params.index, lapply(train.execs, function(texec) {
+          texec$ref.arg.arrays[params.index]
+        }), -params.index)
+      } else {
+        # pull back gradient sums
+        if (!is.null(kvstore)) {
+          kvstore$pull(params.index, lapply(train.execs, function(texec) {
+            texec$ref.grad.arrays[params.index]
+          }), -params.index)
+        }
+        arg.blocks <- lapply(seq_len(ndevice), function(i) {
+          updaters[[i]](train.execs[[i]]$ref.arg.arrays, train.execs[[i]]$ref.grad.arrays)
+        })
+        for (i in seq_len(ndevice)) {
+          mx.exec.update.arg.arrays(train.execs[[i]], arg.blocks[[i]], skip.null = TRUE)
+        }
+      }
+
+      # Update the evaluation metrics; the last input of each slice is the label
+      if (!is.null(metric)) {
+        for (i in seq_len(ndevice)) {
+          train.metric <- metric$update(label = slices[[i]][[length(slices[[i]])]],
+                                        pred = out.preds[[i]], state = train.metric)
+        }
+      }
+
+      nbatch <- nbatch + 1
+
+      if (!is.null(batch.end.callback)) {
+        batch.end.callback(iteration, nbatch, environment())
+      }
+    }
+
+    if (!is.null(metric)) {
+      result <- metric$get(train.metric)
+      if (verbose)
+        message(paste0("[", iteration, "] Train-", result$name, "=", result$value))
+    }
+
+    if (!is.null(eval.data)) {
+      if (!is.null(metric)) {
+        eval.metric <- metric$init()
+      }
+      eval.data$reset()
+      while (eval.data$iter.next()) {
+        # Get iterator data, slice it per device and assign it to the executors
+        dlist <- eval.data$value()[input.names]
+        slices <- slice_inputs(dlist)
+        train.execs <- load_batch(train.execs, slices, names(eval.data$bucketID))
+
+        for (texec in train.execs) {
+          mx.exec.forward(texec, is.train = FALSE)
+        }
+
+        # copy outputs to CPU
+        out.preds <- copy_outputs(train.execs)
+
+        if (!is.null(metric)) {
+          for (i in seq_len(ndevice)) {
+            eval.metric <- metric$update(slices[[i]][[length(slices[[i]])]],
+                                         out.preds[[i]], eval.metric)
+          }
+        }
+      }
+
+      if (!is.null(metric)) {
+        result <- metric$get(eval.metric)
+        if (verbose) {
+          message(paste0("[", iteration, "] Validation-", result$name, "=",
+                         result$value))
+        }
+      }
+    } else {
+      eval.metric <- NULL
+    }
+    # get the model out
+    model <- mx.model.extract.model(sym_ini, train.execs)
+
+    epoch_continue <- TRUE
+    if (!is.null(epoch.end.callback)) {
+      epoch_continue <- epoch.end.callback(iteration, 0, environment(), verbose = verbose)
+    }
+
+    if (!epoch_continue) {
+      break
+    }
+  }
+  return(model)
+}
+
+
+#' Train RNN with bucket support
+#'
+#' @param symbol Symbol, or a list of Symbols indexed by bucket name, representing the model.
+#' @param train.data Training data created by \code{mx.io.bucket.iter}.
+#' @param eval.data Evaluation data created by \code{mx.io.bucket.iter}, or NULL.
+#' @param metric Evaluation metric (an object with \code{init}, \code{update} and
+#'   \code{get}), or NULL to skip metric computation.
+#' @param arg.params Optional named list of NDArrays overriding the initialized arguments.
+#' @param aux.params Optional named list of NDArrays overriding the initialized auxiliary states.
+#' @param fixed.params Character vector of argument names that receive no gradient
+#'   (kept fixed during training).
+#' @param num.round int, the last epoch to train.
+#' @param begin.round int, the first epoch number.
+#' @param initializer Initializer for the parameters and auxiliary states.
+#' @param optimizer An optimizer object, or the name of one built with \code{mx.opt.create}.
+#' @param ctx mx.context or list of mx.context; defaults to \code{mx.ctx.default()}.
+#' @param batch.end.callback Function called after every batch, or NULL.
+#' @param epoch.end.callback Function called after every epoch, or NULL; training
+#'   stops when it returns FALSE.
+#' @param kvstore KVStore object or type name, default "local".
+#' @param verbose logical, whether to print progress messages.
+#' @param ... Extra optimizer parameters (e.g. \code{learning.rate}) forwarded to
+#'   \code{mx.opt.create} when \code{optimizer} is a string.
+#' @return The model extracted from the trained executors.
+#'
+#' @export
+mx.model.buckets <- function(symbol, train.data, eval.data = NULL, metric = NULL,
+                             arg.params = NULL, aux.params = NULL, fixed.params = NULL,
+                             num.round = 1, begin.round = 1,
+                             initializer = mx.init.uniform(0.01), optimizer = "sgd", ctx = NULL,
+                             batch.end.callback = NULL, epoch.end.callback = NULL,
+                             kvstore = "local", verbose = TRUE, ...) {
+
+  # Position the iterators on a batch: value() and bucketID are read below to
+  # infer shapes and to pick the initial bucket symbol.
+  if (!train.data$iter.next()) {
+    train.data$reset()
+    if (!train.data$iter.next())
+      stop("Empty train.data")
+  }
+
+  if (!is.null(eval.data)) {
+    if (!eval.data$iter.next()) {
+      eval.data$reset()
+      if (!eval.data$iter.next())
+        stop("Empty eval.data")
+    }
+  }
+
+  if (is.null(ctx))
+    ctx <- mx.ctx.default()
+  if (is.mx.context(ctx)) {
+    ctx <- list(ctx)
+  }
+  if (!is.list(ctx))
+    stop("ctx must be mx.context or list of mx.context")
+
+  if (is.list(symbol)) sym_ini <- symbol[[names(train.data$bucketID)]] else sym_ini <- symbol
+
+  arguments <- sym_ini$arguments
+  input.names <- intersect(names(train.data$value()), arguments)
+
+  input.shape <- sapply(input.names, function(n) {
+    dim(train.data$value()[[n]])
+  }, simplify = FALSE)
+
+  # The optimizer needs the batch size, so it is built after shape inference.
+  # mx.io.bucket.iter returns every input with the batch as its first R
+  # dimension (rows of the data matrix, entries of seq.mask / label).
+  if (is.character(optimizer)) {
+    batchsize <- input.shape[[1]][[1]]
+    optimizer <- mx.opt.create(optimizer, rescale.grad = (1/batchsize), ...)
+  }
+
+  shapes <- sym_ini$infer.shape(input.shape)
+
+  # assign arg.params and aux.params arguments to arg.params.input and aux.params.input
+  arg.params.input <- arg.params
+  aux.params.input <- aux.params
+
+  # initialize all arguments with zeros
+  arg.params <- lapply(shapes$arg.shapes, function(shape) {
+    mx.nd.zeros(shape = shape, ctx = mx.cpu())
+  })
+
+  # initialize input parameters
+  dlist <- arg.params[input.names]
+
+  # initialize parameters - only argument ending with _weight and _bias are initialized
+  arg.params.ini <- mx.init.create(initializer = initializer, shape.array = shapes$arg.shapes,
+                                   ctx = mx.cpu(), skip.unknown = TRUE)
+
+  # assign initialized parameters to arg.params
+  arg.params[names(arg.params.ini)] <- arg.params.ini
+
+  # assign input params to arg.params
+  arg.params[names(arg.params.input)] <- arg.params.input
+
+  # remove input params from arg.params
+  arg.params[input.names] <- NULL
+
+  # Grad request: "write" for initialized, non-fixed parameters, "null" otherwise
+  grad.req <- rep("null", length(arguments))
+  grad.req.write <- arguments %in% setdiff(names(arg.params.ini), fixed.params)
+  grad.req[grad.req.write] <- "write"
+
+  # Arg array order: position of each symbol argument in c(inputs, arg.params)
+  update_names <- c(input.names, names(arg.params))
+  arg.update.idx <- match(arguments, update_names)
+
+  # aux parameters setup
+  aux.params <- lapply(shapes$aux.shapes, function(shape) {
+    mx.nd.zeros(shape = shape, ctx = mx.cpu())
+  })
+
+  aux.params.ini <- mx.init.create(initializer, shapes$aux.shapes, ctx = mx.cpu(),
+                                   skip.unknown = FALSE)
+  if (length(aux.params) > 0) {
+    aux.params[names(aux.params.ini)] <- aux.params.ini
+  } else aux.params <- NULL
+
+  aux.params[names(aux.params.input)] <- aux.params.input
+
+  # kvstore initialization from the trainable argument arrays
+  kvstore <- mx.model.create.kvstore(kvstore, arg.params, length(ctx),
+                                     verbose = verbose)
+
+  ### Execute training
+  model <- mx.model.train.buckets(symbol = symbol, ctx = ctx, train.data = train.data,
+                                  eval.data = eval.data, dlist = dlist,
+                                  arg.params = arg.params, aux.params = aux.params,
+                                  grad.req = grad.req, arg.update.idx = arg.update.idx,
+                                  optimizer = optimizer, metric = metric,
+                                  begin.round = begin.round, end.round = num.round,
+                                  batch.end.callback = batch.end.callback,
+                                  epoch.end.callback = epoch.end.callback,
+                                  kvstore = kvstore, verbose = verbose)
+
+  return(model)
+}
diff --git a/R-package/R/mx.io.bucket.iter.R b/R-package/R/mx.io.bucket.iter.R
new file mode 100644
index 0000000..8e5ab59
--- /dev/null
+++ b/R-package/R/mx.io.bucket.iter.R
@@ -0,0 +1,110 @@
+
+# Reference class behind mx.io.bucket.iter. `buckets` is a named list whose
+# names are the bucket sequence lengths; each element holds `data` (a matrix
+# with one row per sample) and `label` (a vector or matrix indexed by sample).
+# Each epoch draws fixed-size batches one bucket at a time; the last batch of
+# a bucket is completed by reusing the bucket's first samples.
+BucketIter <- setRefClass("BucketIter",
+  fields = c("buckets", "bucket.names", "batch.size", "data.mask.element",
+             "shuffle", "bucket.plan", "bucketID", "epoch", "batch",
+             "batch.per.bucket", "last.batch.pad", "batch.per.epoch", "seed"),
+  methods = list(
+    initialize = function(buckets, batch.size, data.mask.element = 0,
+                          shuffle = FALSE, seed = 123) {
+      .self$buckets <- buckets
+      .self$bucket.names <- names(.self$buckets)
+      .self$batch.size <- batch.size
+      .self$data.mask.element <- data.mask.element
+      .self$epoch <- 0
+      .self$batch <- 0
+      .self$shuffle <- shuffle
+      .self$batch.per.bucket <- 0
+      .self$batch.per.epoch <- 0
+      .self$bucket.plan <- NULL
+      .self$bucketID <- NULL
+      .self$seed <- seed
+      .self
+    },
+
+    # Start a new epoch: recompute batch counts and padding per bucket and
+    # build the bucket plan, a named integer vector with one entry per batch
+    # (name = bucket, value = batch index within that bucket).
+    reset = function() {
+      # Samples per bucket (rows of each data matrix)
+      buckets.size <- sapply(.self$buckets, function(x) {
+        dim(x$data)[length(dim(x$data)) - 1]
+      })
+      .self$batch.per.bucket <- ceiling(buckets.size/.self$batch.size)
+      # Rows missing from each bucket's last batch; 0 when that batch is full
+      .self$last.batch.pad <- .self$batch.size - buckets.size %% .self$batch.size
+      .self$last.batch.pad[.self$last.batch.pad == .self$batch.size] <- 0
+      # Number of batches per epoch given the batch.size
+      .self$batch.per.epoch <- sum(.self$batch.per.bucket)
+      .self$epoch <- .self$epoch + 1
+      .self$batch <- 0
+
+      if (.self$shuffle) {
+        # NOTE(review): reseeding on every reset gives the same bucket order
+        # each epoch and resets the global RNG; kept for reproducibility.
+        set.seed(.self$seed)
+        bucket_plan_names <- sample(rep(names(.self$batch.per.bucket),
+                                        times = .self$batch.per.bucket))
+        .self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names,
+                                 FUN = cumsum)
+        names(.self$bucket.plan) <- bucket_plan_names
+        ### Return first BucketID at reset for initialization of the model
+        .self$bucketID <- .self$bucket.plan[1]
+
+        # Shuffle the samples inside each bucket. drop = FALSE keeps a bucket
+        # holding a single sample as a matrix, so its size stays readable.
+        .self$buckets <- lapply(.self$buckets, function(x) {
+          shuffle_id <- sample(dim(x$data)[length(dim(x$data)) - 1])
+          if (length(dim(x$label)) == 0) {
+            list(data = x$data[shuffle_id, , drop = FALSE], label = x$label[shuffle_id])
+          } else {
+            list(data = x$data[shuffle_id, , drop = FALSE],
+                 label = x$label[shuffle_id, , drop = FALSE])
+          }
+        })
+      } else {
+        bucket_plan_names <- rep(names(.self$batch.per.bucket),
+                                 times = .self$batch.per.bucket)
+        .self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names,
+                                 FUN = cumsum)
+        names(.self$bucket.plan) <- bucket_plan_names
+      }
+    },
+
+    # Advance to the next batch of the plan; FALSE once the epoch is exhausted.
+    iter.next = function() {
+      .self$batch <- .self$batch + 1
+      .self$bucketID <- .self$bucket.plan[.self$batch]
+      if (.self$batch > .self$batch.per.epoch) {
+        return(FALSE)
+      } else {
+        return(TRUE)
+      }
+    },
+
+    # Current batch as NDArrays: data, seq.mask (bucket sequence length minus
+    # the number of data.mask.element entries in each row) and label.
+    value = function() {
+      # bucketID is a named integer: the integer indicates the batch id for the
+      # given bucket (used to fetch appropriate samples within the bucket); the
+      # name is a character containing the sequence length of the bucket (used
+      # to unroll the rnn to appropriate sequence length)
+      bucket.name <- names(.self$bucketID)
+      bucket <- .self$buckets[[bucket.name]]
+      idx <- (.self$bucketID - 1) * (.self$batch.size) + (1:.self$batch.size)
+      pad <- .self$last.batch.pad[[bucket.name]]
+
+      ### reuse first idx for padding
+      if (.self$bucketID == .self$batch.per.bucket[[bucket.name]] && pad != 0) {
+        # Keep the valid rows of the last batch and fill the rest with the
+        # bucket's first samples, recycled so that a bucket with fewer samples
+        # than the padding never indexes past its last row (NA rows).
+        n.samples <- dim(bucket$data)[length(dim(bucket$data)) - 1]
+        idx <- c(idx[seq_len(.self$batch.size - pad)], rep_len(seq_len(n.samples), pad))
+      }
+
+      data <- bucket$data[idx, , drop = FALSE]
+      seq.mask <- as.integer(bucket.name) - apply(data == .self$data.mask.element, 1, sum)
+      if (length(dim(bucket$label)) == 0) {
+        label <- bucket$label[idx]
+      } else {
+        label <- bucket$label[idx, , drop = FALSE]
+      }
+      return(list(data = mx.nd.array(data), seq.mask = mx.nd.array(seq.mask),
+                  label = mx.nd.array(label)))
+    },
+
+    # Number of padded rows in the current batch (non-zero only for the last,
+    # incomplete batch of a bucket).
+    num.pad = function() {
+      bucket.name <- names(.self$bucketID)
+      pad <- .self$last.batch.pad[[bucket.name]]
+      if (.self$bucketID == .self$batch.per.bucket[[bucket.name]] && pad != 0) {
+        return(.self$last.batch.pad[names(.self$bucketID)])
+      } else return(0)
+    },
+
+    finalize = function() {
+    }
+  ))
+
+#' Create a bucketed data iterator
+#'
+#' Wraps a named list of buckets in a \code{BucketIter} that serves one batch,
+#' taken from a single bucket, at a time.
+#'
+#' @param buckets Named list of buckets; each name is the bucket's sequence
+#'   length and each element holds \code{data} (one row per sample) and
+#'   \code{label}.
+#' @param batch.size The batch size used to pack the array.
+#' @param data.mask.element The data element treated as padding when computing
+#'   the sequence mask.
+#' @param shuffle Whether to shuffle the data.
+#' @param seed The random seed used when shuffling.
+#' @return A \code{BucketIter} reference object.
+#'
+#' @export
+mx.io.bucket.iter <- function(buckets, batch.size, data.mask.element = 0,
+                              shuffle = FALSE, seed = 123) {
+  iter <- BucketIter$new(buckets = buckets,
+                         batch.size = batch.size,
+                         data.mask.element = data.mask.element,
+                         shuffle = shuffle,
+                         seed = seed)
+  iter
+}
diff --git a/R-package/R/rnn.R b/R-package/R/rnn.R
deleted file mode 100644
index b89559a..0000000
--- a/R-package/R/rnn.R
+++ /dev/null
@@ -1,342 +0,0 @@
-# rnn cell symbol
-#
-# One time step of a tanh RNN cell: optional input dropout, then
-# tanh(FC(indata; i2h) + FC(prev.state$h; h2h)), optionally followed by
-# BatchNorm. The FC layers are named "t<seqidx>.l<layeridx>.i2h" / ".h2h".
-# Returns list(h = <new hidden state symbol>).
-rnn <- function(num.hidden, indata, prev.state, param, seqidx, 
-                layeridx, dropout=0., batch.norm=FALSE) {
-    if (dropout > 0. )
-        indata <- mx.symbol.Dropout(data=indata, p=dropout)
-    # input-to-hidden projection
-    i2h <- mx.symbol.FullyConnected(data=indata,
-                                    weight=param$i2h.weight,
-                                    bias=param$i2h.bias,
-                                    num.hidden=num.hidden,
-                                    name=paste0("t", seqidx, ".l", layeridx, 
".i2h"))
-    # hidden-to-hidden projection of the previous state
-    h2h <- mx.symbol.FullyConnected(data=prev.state$h,
-                                    weight=param$h2h.weight,
-                                    bias=param$h2h.bias,
-                                    num.hidden=num.hidden,
-                                    name=paste0("t", seqidx, ".l", layeridx, 
".h2h"))
-    hidden <- i2h + h2h
-
-    hidden <- mx.symbol.Activation(data=hidden, act.type="tanh")
-    if (batch.norm)
-        hidden <- mx.symbol.BatchNorm(data=hidden)
-    return (list(h=hidden))
-}
-
-# unrolled rnn network
-#
-# Training symbol: embeds "data", slices it into seq.len steps and runs
-# num.rnn.layer stacked tanh RNN cells (one weight set per layer, shared by
-# all steps; initial states are the variables "l<i>.init.h"). The top-layer
-# output of every step is concatenated along dim 0 and classified by one
-# FullyConnected + SoftmaxOutput ("sm") against the transposed, flattened
-# "label". NOTE(review): label/output row alignment assumed from the
-# transpose + reshape - confirm.
-rnn.unroll <- function(num.rnn.layer, seq.len, input.size, num.hidden, 
-                       num.embed, num.label, dropout=0., batch.norm=FALSE) {
-    embed.weight <- mx.symbol.Variable("embed.weight")
-    cls.weight <- mx.symbol.Variable("cls.weight")
-    cls.bias <- mx.symbol.Variable("cls.bias")
-    # one set of cell weights per layer, reused at every time step
-    param.cells <- lapply(1:num.rnn.layer, function(i) {
-        cell <- list(i2h.weight = mx.symbol.Variable(paste0("l", i, 
".i2h.weight")),
-                     i2h.bias = mx.symbol.Variable(paste0("l", i, 
".i2h.bias")),
-                     h2h.weight = mx.symbol.Variable(paste0("l", i, 
".h2h.weight")),
-                     h2h.bias = mx.symbol.Variable(paste0("l", i, 
".h2h.bias")))
-        return (cell)
-    })
-    last.states <- lapply(1:num.rnn.layer, function(i) {
-        state <- list(h=mx.symbol.Variable(paste0("l", i, ".init.h")))
-        return (state)
-    })
-
-    # embedding layer
-    label <- mx.symbol.Variable("label")
-    data <- mx.symbol.Variable("data")
-    embed <- mx.symbol.Embedding(data=data, input_dim=input.size,
-                                 weight=embed.weight, output_dim=num.embed, 
name="embed")
-    # one symbol per time step
-    wordvec <- mx.symbol.SliceChannel(data=embed, num_outputs=seq.len, 
squeeze_axis=1)
-
-    last.hidden <- list()
-    for (seqidx in 1:seq.len) { 
-        hidden <- wordvec[[seqidx]]
-        # stack RNN
-        for (i in 1:num.rnn.layer) {
-            # no dropout on the embedding input of the first layer
-            dp <- ifelse(i==1, 0, dropout)
-            next.state <- rnn(num.hidden, indata=hidden,
-                              prev.state=last.states[[i]],
-                              param=param.cells[[i]],
-                              seqidx=seqidx, layeridx=i, 
-                              dropout=dp, batch.norm=batch.norm)
-            hidden <- next.state$h
-            last.states[[i]] <- next.state
-        }
-        # decoder
-        if (dropout > 0.)
-            hidden <- mx.symbol.Dropout(data=hidden, p=dropout)
-        last.hidden <- c(last.hidden, hidden)
-    }
-    # Concat arguments: join the seq.len step outputs along dim 0
-    last.hidden$dim <- 0
-    last.hidden$num.args <- seq.len
-    concat <-mxnet:::mx.varg.symbol.Concat(last.hidden)
-    fc <- mx.symbol.FullyConnected(data=concat,
-                                   weight=cls.weight,
-                                   bias=cls.bias,
-                                   num.hidden=num.label)
-    label <- mx.symbol.transpose(data=label)
-    label <- mx.symbol.Reshape(data=label, target.shape=c(0))
-
-    loss.all <- mx.symbol.SoftmaxOutput(data=fc, label=label, name="sm")
-    return (loss.all)
-}
-
-# rnn inference model symbol
-#
-# Single-step inference symbol: seq.len is accepted but not used, and seqidx
-# is fixed at 0. Embeds "data", runs the stacked tanh RNN cells once and
-# groups the softmax output ("sm") with each layer's new hidden state,
-# exposed through BlockGrad as "l<i>.last.h" (presumably so callers can feed
-# it back as the next "l<i>.init.h" - confirm against the caller).
-rnn.inference.symbol <- function(num.rnn.layer, seq.len, input.size, 
num.hidden, 
-                                 num.embed, num.label, dropout=0., 
batch.norm=FALSE) {
-    seqidx <- 0
-    embed.weight <- mx.symbol.Variable("embed.weight")
-    cls.weight <- mx.symbol.Variable("cls.weight")
-    cls.bias <- mx.symbol.Variable("cls.bias")
-    param.cells <- lapply(1:num.rnn.layer, function(i) {
-        cell <- list(i2h.weight = mx.symbol.Variable(paste0("l", i, 
".i2h.weight")),
-                     i2h.bias = mx.symbol.Variable(paste0("l", i, 
".i2h.bias")),
-                     h2h.weight = mx.symbol.Variable(paste0("l", i, 
".h2h.weight")),
-                     h2h.bias = mx.symbol.Variable(paste0("l", i, 
".h2h.bias")))
-        return (cell)
-    })
-    last.states <- lapply(1:num.rnn.layer, function(i) {
-        state <- list(h=mx.symbol.Variable(paste0("l", i, ".init.h")))
-        return (state)
-    })
-
-    # embedding layer
-    data <- mx.symbol.Variable("data")
-    hidden <- mx.symbol.Embedding(data=data, input_dim=input.size,
-                                 weight=embed.weight, output_dim=num.embed, 
name="embed")
-    # stack RNN        
-    for (i in 1:num.rnn.layer) {
-        # no dropout on the embedding input of the first layer
-        dp <- ifelse(i==1, 0, dropout)
-        next.state <- rnn(num.hidden, indata=hidden,
-                          prev.state=last.states[[i]],
-                          param=param.cells[[i]],
-                          seqidx=seqidx, layeridx=i, 
-                          dropout=dp, batch.norm=batch.norm)
-        hidden <- next.state$h
-        last.states[[i]] <- next.state
-    }
-    # decoder
-    if (dropout > 0.)
-        hidden <- mx.symbol.Dropout(data=hidden, p=dropout)
-
-    fc <- mx.symbol.FullyConnected(data=hidden,
-                                   weight=cls.weight,
-                                   bias=cls.bias,
-                                   num_hidden=num.label)
-    sm <- mx.symbol.SoftmaxOutput(data=fc, name='sm')
-    # expose each layer's final hidden state as an extra, gradient-blocked output
-    unpack.h <- lapply(1:num.rnn.layer, function(i) {
-        state <- last.states[[i]]
-        state.h <- mx.symbol.BlockGrad(state$h, name=paste0("l", i, ".last.h"))
-        return (state.h)
-    })
-    list.all <- c(sm, unpack.h)
-    return (mx.symbol.Group(list.all))
-}
-
-#' Training RNN Unrolled Model
-#'
-#' @param train.data mx.io.DataIter or list(data=R.array, label=R.array)
-#'      The Training set.
-#' @param eval.data mx.io.DataIter or list(data=R.array, label=R.array), 
optional
-#'      The validation set used for validation evaluation during the progress.
-#' @param num.rnn.layer integer
-#'      The number of the layer of rnn.
-#' @param seq.len integer
-#'      The length of the input sequence.
-#' @param num.hidden integer
-#'      The number of hidden nodes.
-#' @param num.embed integer
-#'      The output dim of embedding.
-#' @param num.label  integer
-#'      The number of labels.
-#' @param batch.size integer
-#'      The batch size used for R array training.
-#' @param input.size integer
-#'       The input dim of one-hot encoding of embedding
-#' @param ctx mx.context, optional
-#'      The device used to perform training.
-#' @param num.round integer, default=10
-#'      The number of iterations over training data to train the model.
-#' @param update.period integer, default=1
-#'      The number of iterations to update parameters during training period.
-#' @param initializer initializer object. default=mx.init.uniform(0.01)
-#'      The initialization scheme for parameters.
-#' @param dropout float, default=0
-#'      A number in [0,1) containing the dropout ratio from the last hidden 
layer to the output layer.
-#' @param optimizer string, default="sgd"
-#'      The optimization method. NOTE(review): not forwarded to
-#'      setup.rnn.model or train.rnn by this function.
-#' @param batch.norm boolean, default=FALSE
-#'      Whether to use batch normalization.
-#' @param ... other parameters passed to \code{train.rnn}.
-#' @return model A trained rnn unrolled model.
-#'
-#' @export
-mx.rnn <- function( train.data, eval.data=NULL,
-                    num.rnn.layer, seq.len,
-                    num.hidden, num.embed, num.label,
-                    batch.size, input.size,
-                    ctx=mx.ctx.default(),
-                    num.round=10, update.period=1,
-                    initializer=mx.init.uniform(0.01),
-                    dropout=0, optimizer='sgd',
-                    batch.norm=FALSE,
-                    ...) {
-    # check data and change data into iterator
-    train.data <- check.data(train.data, batch.size, TRUE)
-    eval.data <- check.data(eval.data, batch.size, FALSE)
-
-    # get unrolled rnn symbol
-    rnn.sym <- rnn.unroll( num.rnn.layer=num.rnn.layer,
-                           num.hidden=num.hidden,
-                           seq.len=seq.len,
-                           input.size=input.size,
-                           num.embed=num.embed,
-                           num.label=num.label,
-                           dropout=dropout,
-                           batch.norm=batch.norm)
-    # names of the per-layer initial hidden-state inputs ("l<i>.init.h")
-    init.states.name <- lapply(1:num.rnn.layer, function(i) {
-        state <- paste0("l", i, ".init.h")
-        return (state)
-    })
-    # set up rnn model
-    model <- setup.rnn.model(rnn.sym=rnn.sym,
-                             ctx=ctx,
-                             num.rnn.layer=num.rnn.layer,
-                             seq.len=seq.len,
-                             num.hidden=num.hidden,
-                             num.embed=num.embed,
-                             num.label=num.label,
-                             batch.size=batch.size,
-                             input.size=input.size,
-                             init.states.name=init.states.name,
-                             initializer=initializer,
-                             dropout=dropout)
-    # train rnn model
-    model <- train.rnn( model, train.data, eval.data,
-                        num.round=num.round,
-                        update.period=update.period,
-                        ctx=ctx,
-                        init.states.name=init.states.name,
-                        ...)
-    # change model into MXFeedForwardModel
-    model <- list(symbol=model$symbol, 
arg.params=model$rnn.exec$ref.arg.arrays, 
aux.params=model$rnn.exec$ref.aux.arrays)
-    return(structure(model, class="MXFeedForwardModel"))
-}
-
-#' Create a RNN Inference Model
-#'
-#' @param num.rnn.layer integer
-#'      The number of the layer of rnn.
-#' @param input.size integer
-#'       The input dim of one-hot encoding of embedding
-#' @param num.hidden integer
-#'      The number of hidden nodes.
-#' @param num.embed integer
-#'      The output dim of embedding.
-#' @param num.label  integer
-#'      The number of labels.
-#' @param batch.size integer, default=1
-#'      The batch size used for R array training.
-#' @param arg.params list
-#'      Model parameter, list of name to NDArray of net's weights.
-#' @param ctx mx.context, optional
-#'      The device used to perform inference.
-#' @param dropout float, default=0
-#'      A number in [0,1) containing the dropout ratio from the last hidden 
layer to the output layer.
-#' @param batch.norm boolean, default=FALSE
-#'      Whether to use batch normalization.
-#' @return model list(rnn.exec=integer, symbol=mxnet symbol, 
num.rnn.layer=integer, num.hidden=integer, seq.len=integer, batch.size=integer, 
num.embed=integer) 
-#'      A rnn inference model.
-#'
-#' @export
-mx.rnn.inference <- function( num.rnn.layer,
-                              input.size,
-                              num.hidden,
-                              num.embed,
-                              num.label,
-                              batch.size=1,
-                              arg.params,
-                              ctx=mx.cpu(),
-                              dropout=0.,
-                              batch.norm=FALSE) {
-    sym <- rnn.inference.symbol( num.rnn.layer=num.rnn.layer,
-                                 input.size=input.size,
-                                 num.hidden=num.hidden,
-                                 num.embed=num.embed,
-                                 num.label=num.label,
-                                 dropout=dropout,
-                                 batch.norm=batch.norm)
-    # init.states.name <- c()
-    # for (i in 1:num.rnn.layer) {
-    #     init.states.name <- c(init.states.name, paste0("l", i, ".init.c"))
-    #     init.states.name <- c(init.states.name, paste0("l", i, ".init.h"))
-    # }
-    init.states.name <- lapply(1:num.rnn.layer, function(i) {
-        state <- paste0("l", i, ".init.h")
-        return (state)
-    })
-    
-    seq.len <- 1
-    # set up rnn model
-    model <- setup.rnn.model(rnn.sym=sym,
-                             ctx=ctx,
-                             num.rnn.layer=num.rnn.layer,
-                             seq.len=seq.len,
-                             num.hidden=num.hidden,
-                             num.embed=num.embed,
-                             num.label=num.label,
-                             batch.size=batch.size,
-                             input.size=input.size,
-                             init.states.name=init.states.name,
-                             initializer=mx.init.uniform(0.01),
-                             dropout=dropout)
-    arg.names <- names(model$rnn.exec$ref.arg.arrays)
-    for (k in names(arg.params)) {
-        if ((k %in% arg.names) && is.param.name(k) ) {
-            rnn.input <- list()
-            rnn.input[[k]] <- arg.params[[k]]
-            mx.exec.update.arg.arrays(model$rnn.exec, rnn.input, 
match.name=TRUE)
-        }
-    }
-    init.states <- list()
-    for (i in 1:num.rnn.layer) {
-        init.states[[paste0("l", i, ".init.h")]] <- 
model$rnn.exec$ref.arg.arrays[[paste0("l", i, ".init.h")]]*0
-    }
-    mx.exec.update.arg.arrays(model$rnn.exec, init.states, match.name=TRUE)
-
-    return (model)
-}
-
-#' Using forward function to predict in rnn inference model
-#'
-#' @param model rnn model
-#'      A rnn inference model
-#' @param input.data, array.matrix
-#'      The input data for forward function
-#' @param new.seq boolean, default=FALSE
-#'      Whether the input is the start of a new sequence
-#'
-#' @return result A list(prob=prob, model=model) containing the result 
probability of each label and the model.
-#'
-#' @export
-mx.rnn.forward <- function(model, input.data, new.seq=FALSE) {
-    if (new.seq == TRUE) {
-        init.states <- list()
-        for (i in 1:model$num.rnn.layer) {
-            init.states[[paste0("l", i, ".init.h")]] <- 
model$rnn.exec$ref.arg.arrays[[paste0("l", i, ".init.h")]]*0
-        }
-        mx.exec.update.arg.arrays(model$rnn.exec, init.states, match.name=TRUE)
-    }
-    dim(input.data) <- c(model$batch.size)
-    data <- list(data=mx.nd.array(input.data))
-    mx.exec.update.arg.arrays(model$rnn.exec, data, match.name=TRUE)
-    mx.exec.forward(model$rnn.exec, is.train=FALSE)
-    init.states <- list()
-    for (i in 1:model$num.rnn.layer) {
-        init.states[[paste0("l", i, ".init.h")]] <- 
model$rnn.exec$ref.outputs[[paste0("l", i, ".last.h_output")]]
-    }
-    mx.exec.update.arg.arrays(model$rnn.exec, init.states, match.name=TRUE)
-    #print (model$rnn.exec$ref)
-    prob <- model$rnn.exec$ref.outputs[["sm_output"]]
-    print ("prob")
-    print (prob)
-    return (list(prob=prob, model=model))
-}
diff --git a/R-package/R/rnn.graph.R b/R-package/R/rnn.graph.R
new file mode 100644
index 0000000..11e5ef5
--- /dev/null
+++ b/R-package/R/rnn.graph.R
@@ -0,0 +1,283 @@
+#
+#' Generate a RNN symbolic model - requires CUDA
+#'
+#' Builds an embedding layer, the fused \code{mx.symbol.RNN} operator and a
+#' softmax classifier on top of it.
+#'
+#' @param num.rnn.layer int, number of stacked layers
+#' @param input.size int, number of levels in the data
+#' @param num.embed int, dimension of the embedding vectors
+#' @param num.hidden int, size of the state in each RNN layer
+#' @param num.label int, number of categories in labels
+#' @param dropout numeric, dropout probability passed as \code{p} to the RNN operator
+#' @param ignore_label label value ignored by the softmax loss; -1 disables ignoring
+#' @param config Either "seq-to-one" or "one-to-one"
+#' @param cell.type Type of RNN cell: either "gru" or "lstm"
+#' @param masking logical, whether the \code{seq.mask} input is used as sequence lengths
+#' @param output_last_state logical, passed as \code{state.outputs} to the RNN operator
+#'
+#' @return The loss symbol (\code{mx.symbol.SoftmaxOutput}) named "loss".
+#'
+#' @export
+rnn.graph <- function(num.rnn.layer,
+                      input.size,
+                      num.embed,
+                      num.hidden,
+                      num.label,
+                      dropout = 0,
+                      ignore_label = -1,
+                      config,
+                      cell.type,
+                      masking = FALSE,
+                      output_last_state = FALSE) {
+
+  # Fail early with a clear message instead of "object 'loss' not found"
+  if (!(length(config) == 1 && config %in% c("seq-to-one", "one-to-one"))) {
+    stop("config must be either \"seq-to-one\" or \"one-to-one\"", call. = FALSE)
+  }
+
+  # define input arguments
+  label <- mx.symbol.Variable("label")
+  data <- mx.symbol.Variable("data")
+  seq.mask <- mx.symbol.Variable("seq.mask")
+
+  embed.weight <- mx.symbol.Variable("embed.weight")
+  rnn.params.weight <- mx.symbol.Variable("rnn.params.weight")
+
+  rnn.state <- mx.symbol.Variable("rnn.state")
+
+  # Only the LSTM carries a separate cell state
+  if (cell.type == "lstm") {
+    rnn.state.cell <- mx.symbol.Variable("rnn.state.cell")
+  }
+
+  cls.weight <- mx.symbol.Variable("cls.weight")
+  cls.bias <- mx.symbol.Variable("cls.bias")
+
+  embed <- mx.symbol.Embedding(data = data, input_dim = input.size,
+                               weight = embed.weight, output_dim = num.embed,
+                               name = "embed")
+
+  rnn.name <- paste(cell.type, num.rnn.layer, "layer", sep = "_")
+
+  # RNN cells
+  if (cell.type == "lstm") {
+    rnn <- mx.symbol.RNN(data = embed, state = rnn.state,
+                         state_cell = rnn.state.cell,
+                         parameters = rnn.params.weight,
+                         state.size = num.hidden, num.layers = num.rnn.layer,
+                         bidirectional = FALSE, mode = cell.type,
+                         state.outputs = output_last_state, p = dropout,
+                         name = rnn.name)
+  } else {
+    rnn <- mx.symbol.RNN(data = embed, state = rnn.state,
+                         parameters = rnn.params.weight,
+                         state.size = num.hidden, num.layers = num.rnn.layer,
+                         bidirectional = FALSE, mode = cell.type,
+                         state.outputs = output_last_state, p = dropout,
+                         name = rnn.name)
+  }
+
+  # Decode
+  if (config == "seq-to-one") {
+    # Keep only the last time step (per-sequence lengths when masking)
+    if (masking) {
+      mask <- mx.symbol.SequenceLast(data = rnn[[1]], use.sequence.length = TRUE,
+                                     sequence_length = seq.mask, name = "mask")
+    } else {
+      mask <- mx.symbol.SequenceLast(data = rnn[[1]], use.sequence.length = FALSE,
+                                     name = "mask")
+    }
+
+    fc <- mx.symbol.FullyConnected(data = mask,
+                                   weight = cls.weight,
+                                   bias = cls.bias,
+                                   num.hidden = num.label,
+                                   name = "decode")
+
+    loss <- mx.symbol.SoftmaxOutput(data = fc, label = label,
+                                    use_ignore = ignore_label != -1,
+                                    ignore_label = ignore_label, name = "loss")
+
+  } else {
+    # one-to-one: every time step gets its own prediction; masked steps are zeroed
+    if (masking) {
+      mask <- mx.symbol.SequenceMask(data = rnn[[1]], use.sequence.length = TRUE,
+                                     sequence_length = seq.mask, value = 0,
+                                     name = "mask")
+    } else {
+      mask <- mx.symbol.identity(data = rnn[[1]], name = "mask")
+    }
+
+    # Collapse every dimension except the hidden one before the classifier
+    reshape <- mx.symbol.reshape(mask, shape = c(num.hidden, -1))
+
+    decode <- mx.symbol.FullyConnected(data = reshape,
+                                       weight = cls.weight,
+                                       bias = cls.bias,
+                                       num.hidden = num.label,
+                                       name = "decode")
+
+    label <- mx.symbol.reshape(data = label, shape = c(-1), name = "label_reshape")
+    loss <- mx.symbol.SoftmaxOutput(data = decode, label = label,
+                                    use_ignore = ignore_label != -1,
+                                    ignore_label = ignore_label, name = "loss")
+  }
+  return(loss)
+}
+
+
+# LSTM cell symbol
+#
+# Builds one LSTM step for time step `seqidx` of layer `layeridx`, using the
+# symbols param$i2h.weight, param$i2h.bias, param$h2h.weight and param$h2h.bias.
+# When prev.state is NULL, the hidden-to-hidden projection and the forget-gate
+# term are left out. Returns list(c = next cell state, h = next hidden state).
+# Symbols are created in a fixed order so auto-generated node names stay stable.
+lstm.cell <- function(num.hidden, indata, prev.state, param, seqidx, layeridx, dropout = 0) {
+  prefix <- paste0("t", seqidx, ".l", layeridx)
+  has.prev <- !is.null(prev.state)
+
+  gates <- mx.symbol.FullyConnected(data = indata,
+                                    weight = param$i2h.weight,
+                                    bias = param$i2h.bias,
+                                    num.hidden = num.hidden * 4,
+                                    name = paste0(prefix, ".i2h"))
+  if (dropout > 0) {
+    gates <- mx.symbol.Dropout(data = gates, p = dropout)
+  }
+  if (has.prev) {
+    gates <- gates + mx.symbol.FullyConnected(data = prev.state$h,
+                                              weight = param$h2h.weight,
+                                              bias = param$h2h.bias,
+                                              num.hidden = num.hidden * 4,
+                                              name = paste0(prefix, ".h2h"))
+  }
+
+  # Four equal slices: input gate, input transform, forget gate, output gate
+  slices <- mx.symbol.split(gates, num.outputs = 4, axis = 1, squeeze.axis = FALSE,
+                            name = paste0(prefix, ".slice"))
+  in.gate <- mx.symbol.Activation(slices[[1]], act.type = "sigmoid")
+  in.transform <- mx.symbol.Activation(slices[[2]], act.type = "tanh")
+  forget.gate <- mx.symbol.Activation(slices[[3]], act.type = "sigmoid")
+  out.gate <- mx.symbol.Activation(slices[[4]], act.type = "sigmoid")
+
+  next.c <- if (has.prev) {
+    (forget.gate * prev.state$c) + (in.gate * in.transform)
+  } else {
+    in.gate * in.transform
+  }
+  next.h <- out.gate * mx.symbol.Activation(next.c, act.type = "tanh")
+
+  list(c = next.c, h = next.h)
+}
+
+# GRU cell symbol
+#
+# Builds one GRU step for time step `seqidx` of layer `layeridx`, using the
+# gates.* and trans.* weight/bias symbols in `param`. When prev.state is NULL
+# (no previous hidden state), the gates come from the input alone and the
+# candidate is scaled by the update gate directly. Returns list(h = next hidden state).
+# Symbols are created in a fixed order so auto-generated node names stay stable.
+gru.cell <- function(num.hidden, indata, prev.state, param, seqidx, layeridx, dropout = 0) {
+  prefix <- paste0("t", seqidx, ".l", layeridx)
+  first.step <- is.null(prev.state)
+
+  gates <- mx.symbol.FullyConnected(data = indata,
+                                    weight = param$gates.i2h.weight,
+                                    bias = param$gates.i2h.bias,
+                                    num.hidden = num.hidden * 2,
+                                    name = paste0(prefix, ".gates.i2h"))
+  if (dropout > 0) {
+    gates <- mx.symbol.Dropout(data = gates, p = dropout)
+  }
+  if (!first.step) {
+    gates <- gates + mx.symbol.FullyConnected(data = prev.state$h,
+                                              weight = param$gates.h2h.weight,
+                                              bias = param$gates.h2h.bias,
+                                              num.hidden = num.hidden * 2,
+                                              name = paste0(prefix, ".gates.h2h"))
+  }
+
+  # Two equal halves: update gate, reset gate
+  halves <- mx.symbol.split(gates, num.outputs = 2, axis = 1, squeeze.axis = FALSE,
+                            name = paste0(prefix, ".split"))
+  update.gate <- mx.symbol.Activation(halves[[1]], act.type = "sigmoid")
+  reset.gate <- mx.symbol.Activation(halves[[2]], act.type = "sigmoid")
+
+  htrans.i2h <- mx.symbol.FullyConnected(data = indata,
+                                         weight = param$trans.i2h.weight,
+                                         bias = param$trans.i2h.bias,
+                                         num.hidden = num.hidden,
+                                         name = paste0(prefix, ".trans.i2h"))
+
+  # With no previous state, a zero tensor shaped like reset.gate stands in for it
+  h.after.reset <- if (first.step) reset.gate * 0 else prev.state$h * reset.gate
+
+  htrans.h2h <- mx.symbol.FullyConnected(data = h.after.reset,
+                                         weight = param$trans.h2h.weight,
+                                         bias = param$trans.h2h.bias,
+                                         num.hidden = num.hidden,
+                                         name = paste0(prefix, ".trans.h2h"))
+
+  h.trans <- htrans.i2h + htrans.h2h
+  candidate <- mx.symbol.Activation(h.trans, act.type = "tanh")
+
+  next.h <- if (first.step) {
+    update.gate * candidate
+  } else {
+    prev.state$h + update.gate * (candidate - prev.state$h)
+  }
+
+  list(h = next.h)
+}
+
+#
+#' Unrolled representation of a RNN, for devices without CUDA - under development
+#'
+#' Builds one explicit LSTM or GRU cell per time step and layer. Each layer's
+#' weights are shared across all time steps.
+#'
+#' @param num.rnn.layer int, number of stacked layers
+#' @param seq.len int, number of time steps to unroll
+#' @param input.size int, number of levels in the data
+#' @param num.embed int, dimension of the embedding vectors
+#' @param num.hidden int, size of the state in each RNN layer
+#' @param num.label int, number of categories in labels
+#' @param dropout numeric, dropout probability applied inside each cell when > 0
+#' @param ignore_label label value ignored by the softmax loss; -1 disables ignoring
+#' @param init.state optional list with one initial state per layer, shaped like
+#'   the cell outputs (\code{h}, plus \code{c} for LSTM); NULL starts without state
+#' @param config Either "seq-to-one" or "one-to-one"
+#' @param cell.type Type of RNN cell: either "lstm" or "gru"
+#' @param masking currently unused
+#' @param output_last_state currently unused
+#'
+#' @return The loss symbol (\code{mx.symbol.SoftmaxOutput}) named "sm".
+#'
+#' @export
+rnn.unroll <- function(num.rnn.layer,
+                       seq.len,
+                       input.size,
+                       num.embed,
+                       num.hidden,
+                       num.label,
+                       dropout,
+                       ignore_label,
+                       init.state = NULL,
+                       config,
+                       cell.type = "lstm",
+                       masking = FALSE,
+                       output_last_state = FALSE) {
+
+  # Fail early instead of "object 'cell' / 'loss' not found"
+  if (!(length(cell.type) == 1 && cell.type %in% c("lstm", "gru"))) {
+    stop("cell.type must be either \"lstm\" or \"gru\"", call. = FALSE)
+  }
+  if (!(length(config) == 1 && config %in% c("seq-to-one", "one-to-one"))) {
+    stop("config must be either \"seq-to-one\" or \"one-to-one\"", call. = FALSE)
+  }
+
+  cell.symbol <- if (cell.type == "lstm") lstm.cell else gru.cell
+
+  embed.weight <- mx.symbol.Variable("embed.weight")
+  cls.weight <- mx.symbol.Variable("cls.weight")
+  cls.bias <- mx.symbol.Variable("cls.bias")
+
+  # Per-layer weight symbols named "l<i>.<suffix>", shared by every time step
+  param.suffixes <- if (cell.type == "lstm") {
+    c("i2h.weight", "i2h.bias", "h2h.weight", "h2h.bias")
+  } else {
+    c("gates.i2h.weight", "gates.i2h.bias", "gates.h2h.weight", "gates.h2h.bias",
+      "trans.i2h.weight", "trans.i2h.bias", "trans.h2h.weight", "trans.h2h.bias")
+  }
+  param.cells <- lapply(seq_len(num.rnn.layer), function(i) {
+    cell <- lapply(param.suffixes, function(s) mx.symbol.Variable(paste0("l", i, ".", s)))
+    names(cell) <- param.suffixes
+    cell
+  })
+
+  # embedding layer
+  label <- mx.symbol.Variable("label")
+  data <- mx.symbol.Variable("data")
+
+  embed <- mx.symbol.Embedding(data = data, input_dim = input.size,
+                               weight = embed.weight, output_dim = num.embed,
+                               name = "embed")
+
+  # One embedded input per time step
+  embed <- mx.symbol.split(data = embed, axis = 0, num.outputs = seq.len,
+                           squeeze_axis = TRUE)
+
+  last.hidden <- list()
+  last.states <- list()
+
+  for (seqidx in seq_len(seq.len)) {
+    hidden <- embed[[seqidx]]
+
+    for (i in seq_len(num.rnn.layer)) {
+      prev.state <- if (seqidx == 1) init.state[[i]] else last.states[[i]]
+
+      next.state <- cell.symbol(num.hidden = num.hidden,
+                                indata = hidden,
+                                prev.state = prev.state,
+                                param = param.cells[[i]],
+                                seqidx = seqidx,
+                                layeridx = i,
+                                dropout = dropout)
+      hidden <- next.state$h
+      last.states[[i]] <- next.state
+    }
+
+    # one-to-one keeps the top-layer output of every time step for decoding
+    if (config == "one-to-one") {
+      last.hidden <- c(last.hidden, hidden)
+    }
+  }
+
+  if (config == "seq-to-one") {
+    fc <- mx.symbol.FullyConnected(data = hidden,
+                                   weight = cls.weight,
+                                   bias = cls.bias,
+                                   num.hidden = num.label)
+
+    loss <- mx.symbol.SoftmaxOutput(data = fc, name = "sm", label = label,
+                                    use_ignore = ignore_label != -1,
+                                    ignore_label = ignore_label)
+
+  } else {
+    # concat hidden units - concat seq.len blocks of dimension num.hidden x batch.size
+    concat <- mx.symbol.concat(data = last.hidden, num.args = seq.len, dim = 0,
+                               name = "concat")
+
+    decode <- mx.symbol.FullyConnected(data = concat,
+                                       weight = cls.weight,
+                                       bias = cls.bias,
+                                       num.hidden = num.label,
+                                       name = "decode")
+
+    label <- mx.symbol.reshape(data = label, shape = -1, name = "label_reshape")
+    loss <- mx.symbol.SoftmaxOutput(data = decode, name = "sm", label = label,
+                                    use_ignore = ignore_label != -1,
+                                    ignore_label = ignore_label)
+  }
+  return(loss)
+}
diff --git a/R-package/R/rnn.infer.R b/R-package/R/rnn.infer.R
new file mode 100644
index 0000000..c9ccecb
--- /dev/null
+++ b/R-package/R/rnn.infer.R
@@ -0,0 +1,177 @@
+#
+#' Inference of RNN model
+#'
+#' Runs a forward pass over every batch of \code{infer.data} and packs the
+#' first output of each batch, dropping padded samples.
+#'
+#' @param infer.data Data iterator created by mx.io.bucket.iter
+#' @param model Model used for inference: a list with \code{symbol} (a symbol,
+#'   or a list of symbols indexed by \code{names(infer.data$bucketID)}),
+#'   \code{arg.params} and \code{aux.params}
+#' @param ctx mx.context or list of mx.context; only the first one is used.
+#'   NULL means \code{mx.ctx.default()}.
+#'
+#' @return The packed predictions of all batches (\code{mx.nd.arraypacker()$get()}).
+#'
+#' @export
+mx.infer.buckets <- function(infer.data, model, ctx = mx.cpu()) {
+
+  ### Initialise the iterator: the first batch is used to infer shapes
+  infer.data$reset()
+  infer.data$iter.next()
+
+  if (is.null(ctx))
+    ctx <- mx.ctx.default()
+  if (is.mx.context(ctx)) {
+    ctx <- list(ctx)
+  }
+  if (!is.list(ctx))
+    stop("ctx must be mx.context or list of mx.context")
+
+  symbol <- model$symbol
+  # Symbol for the iterator's current bucket (a plain symbol is used as is)
+  bucket.symbol <- function() {
+    if (is.list(symbol)) symbol[[names(infer.data$bucketID)]] else symbol
+  }
+  sym_ini <- bucket.symbol()
+
+  arguments <- sym_ini$arguments
+  input.names <- intersect(names(infer.data$value()), arguments)
+  input.shape <- lapply(infer.data$value()[input.names], dim)
+
+  shapes <- sym_ini$infer.shape(input.shape)
+
+  # initialize all arguments with zeros
+  arguments.ini <- lapply(shapes$arg.shapes, function(shape) {
+    mx.nd.zeros(shape = shape, ctx = mx.cpu())
+  })
+
+  arg.params <- model$arg.params
+  arg.params.names <- names(arg.params)
+  aux.params <- model$aux.params
+
+  # Initial binding
+  dlist <- arguments.ini[input.names]
+
+  # Arguments that are neither inputs nor trained parameters keep their zero value
+  arg.params.fix.names <- setdiff(arguments, c(arg.params.names, input.names))
+  arg.params.fix <- arguments.ini[arg.params.fix.names]
+
+  # No gradients are needed for inference
+  grad.req <- rep("null", length(arguments))
+
+  # Positions that reorder c(inputs, fixed, params) into the symbol's argument order
+  update_names <- c(input.names, arg.params.fix.names, arg.params.names)
+  arg_update_idx <- match(arguments, update_names)
+
+  execs <- mx.symbol.bind(symbol = sym_ini,
+                          arg.arrays = c(dlist, arg.params.fix, arg.params)[arg_update_idx],
+                          aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
+
+  # Single device only: multi-device inference would need to divide the
+  # highest dimension of the inputs by the number of devices
+  packer <- mx.nd.arraypacker()
+  infer.data$reset()
+  while (infer.data$iter.next()) {
+
+    # Subset to input.names so positions line up with arg_update_idx
+    dlist <- infer.data$value()[input.names]
+
+    execs <- mx.symbol.bind(symbol = bucket.symbol(),
+                            arg.arrays = c(dlist,
+                                           execs$arg.arrays[arg.params.fix.names],
+                                           execs$arg.arrays[arg.params.names])[arg_update_idx],
+                            aux.arrays = execs$aux.arrays, ctx = ctx[[1]],
+                            grad.req = grad.req)
+
+    mx.exec.forward(execs, is.train = FALSE)
+
+    out.pred <- mx.nd.copyto(execs$ref.outputs[[1]], mx.cpu())
+    padded <- infer.data$num.pad()
+    oshape <- dim(out.pred)
+    ndim <- length(oshape)
+    # Drop padded samples; end is taken from the last R dimension
+    packer$push(mx.nd.slice.axis(data = out.pred, axis = 0, begin = 0,
+                                 end = oshape[[ndim]] - padded))
+  }
+  infer.data$reset()
+  return(packer$get())
+}
+
+
+
+### inference for one-to-one models
+# Runs a forward pass over each batch of infer.data and returns all outputs of
+# the LAST batch, copied to the CPU (NULL if the iterator yields no batch).
+#
+# infer.data:   iterator with reset()/iter.next()/value()/num.pad() methods
+# symbol:       symbol to bind
+# arg.params:   named list of trained argument arrays
+# aux.params:   auxiliary arrays passed to mx.symbol.bind
+# input.params: optional named list of arrays overriding the zero initialisation
+#               of matching arguments (presumably initial states - confirm with callers)
+# ctx:          mx.context or list of mx.context; only the first one is used
+mx.infer.buckets.one <- function(infer.data,
+                                 symbol, arg.params, aux.params, input.params = NULL,
+                                 ctx = mx.cpu()) {
+
+  ### Initialise the iterator: the first batch is used to infer shapes
+  infer.data$reset()
+  infer.data$iter.next()
+
+  if (is.null(ctx))
+    ctx <- mx.ctx.default()
+  if (is.mx.context(ctx)) {
+    ctx <- list(ctx)
+  }
+  if (!is.list(ctx))
+    stop("ctx must be mx.context or list of mx.context")
+
+  arguments <- symbol$arguments
+  input.names <- intersect(names(infer.data$value()), arguments)
+  input.shape <- lapply(infer.data$value()[input.names], dim)
+
+  shapes <- symbol$infer.shape(input.shape)
+
+  # initialize all arguments with zeros
+  arguments.ini <- lapply(shapes$arg.shapes, function(shape) {
+    mx.nd.zeros(shape = shape, ctx = mx.cpu())
+  })
+
+  arg.params.names <- names(arg.params)
+
+  dlist <- arguments.ini[input.names]
+
+  # "Fixed" arguments: those named in input.params plus every argument that is
+  # neither an input nor a trained parameter; zero unless overridden below
+  arg.params.fix.names <- unique(c(names(input.params),
+                                   setdiff(arguments, c(arg.params.names, input.names))))
+  arg.params.fix <- arguments.ini[arg.params.fix.names]
+  arg.params.fix[names(input.params)] <- input.params
+
+  # No gradients are needed for inference
+  grad.req <- rep("null", length(arguments))
+
+  # Positions that reorder c(inputs, fixed, params) into the symbol's argument order
+  update_names <- c(input.names, arg.params.fix.names, arg.params.names)
+  arg_update_idx <- match(arguments, update_names)
+
+  # Initial binding
+  execs <- mx.symbol.bind(symbol = symbol,
+                          arg.arrays = c(dlist, arg.params.fix, arg.params)[arg_update_idx],
+                          aux.arrays = aux.params, ctx = ctx[[1]], grad.req = grad.req)
+
+  # Single device only: multi-device inference would need to divide the
+  # highest dimension of the inputs by the number of devices
+  out <- NULL
+  infer.data$reset()
+  while (infer.data$iter.next()) {
+
+    # Subset to input.names so positions line up with arg_update_idx
+    dlist <- infer.data$value()[input.names]
+
+    execs <- mx.symbol.bind(symbol = symbol,
+                            arg.arrays = c(dlist,
+                                           execs$arg.arrays[arg.params.fix.names],
+                                           execs$arg.arrays[arg.params.names])[arg_update_idx],
+                            aux.arrays = execs$aux.arrays, ctx = ctx[[1]],
+                            grad.req = grad.req)
+
+    mx.exec.forward(execs, is.train = FALSE)
+
+    # Copy every output instead of indexing fixed positions, which fails for
+    # symbols exposing fewer outputs (e.g. GRU, or no last-state outputs)
+    out <- lapply(execs$ref.outputs, function(o) {
+      mx.nd.copyto(o, mx.cpu())
+    })
+  }
+  infer.data$reset()
+  return(out)
+}
diff --git a/R-package/R/rnn_model.R b/R-package/R/rnn_model.R
deleted file mode 100644
index aa4a7d0..0000000
--- a/R-package/R/rnn_model.R
+++ /dev/null
@@ -1,258 +0,0 @@
-is.param.name <- function(name) {
-    return (grepl('weight$', name) || grepl('bias$', name) ||
-           grepl('gamma$', name) || grepl('beta$', name) )
-}
-
-# Initialize the data iter
-mx.model.init.iter.rnn <- function(X, y, batch.size, is.train) {
-  if (is.mx.dataiter(X)) return(X)
-  shape <- dim(X)
-  if (is.null(shape)) {
-    num.data <- length(X)
-  } else {
-    ndim <- length(shape)
-    num.data <- shape[[ndim]]
-  }
-  if (is.null(y)) {
-    if (is.train) stop("Need to provide parameter y for training with R 
arrays.")
-    y <- c(1:num.data) * 0
-  }
-
-  batch.size <- min(num.data, batch.size)
-
-  return(mx.io.arrayiter(X, y, batch.size=batch.size, shuffle=is.train))
-}
-
-# set up rnn model with rnn cells
-setup.rnn.model <- function(rnn.sym, ctx,
-                            num.rnn.layer, seq.len,
-                            num.hidden, num.embed, num.label,
-                            batch.size, input.size,
-                            init.states.name,
-                            initializer=mx.init.uniform(0.01),
-                            dropout=0) {
-
-    arg.names <- rnn.sym$arguments
-    input.shapes <- list()
-    for (name in arg.names) {
-        if (name %in% init.states.name) {
-            input.shapes[[name]] <- c(num.hidden, batch.size)
-        }
-        else if (grepl('data$', name) || grepl('label$', name) ) {
-            if (seq.len == 1) {
-                input.shapes[[name]] <- c(batch.size)
-            } else {
-            input.shapes[[name]] <- c(seq.len, batch.size)
-            }
-        }
-    }
-    params <- mx.model.init.params(rnn.sym, input.shapes, NULL, initializer, 
mx.cpu())
-    args <- input.shapes
-    args$symbol <- rnn.sym
-    args$ctx <- ctx
-    args$grad.req <- "write"
-    rnn.exec <- do.call(mx.simple.bind, args)
-
-    mx.exec.update.arg.arrays(rnn.exec, params$arg.params, match.name=TRUE)
-    mx.exec.update.aux.arrays(rnn.exec, params$aux.params, match.name=TRUE)
-
-    grad.arrays <- list()
-    for (name in names(rnn.exec$ref.grad.arrays)) {
-        if (is.param.name(name))
-            grad.arrays[[name]] <- rnn.exec$ref.arg.arrays[[name]]*0
-    }
-    mx.exec.update.grad.arrays(rnn.exec, grad.arrays, match.name=TRUE)
-
-    return (list(rnn.exec=rnn.exec, symbol=rnn.sym,
-                 num.rnn.layer=num.rnn.layer, num.hidden=num.hidden,
-                 seq.len=seq.len, batch.size=batch.size,
-                 num.embed=num.embed))
-
-}
-
-
-calc.nll <- function(seq.label.probs, batch.size) {
-    nll = - sum(log(seq.label.probs)) / batch.size
-    return (nll)
-}
-
-get.label <- function(label, ctx) {
-    label <- as.array(label)
-    seq.len <- dim(label)[[1]]
-    batch.size <- dim(label)[[2]]
-    sm.label <- array(0, dim=c(seq.len*batch.size))
-    for (seqidx in 1:seq.len) {
-        sm.label[((seqidx-1)*batch.size+1) : (seqidx*batch.size)] <- 
label[seqidx,]
-    }
-    return (mx.nd.array(sm.label, ctx))
-}
-
-
-# training rnn model
-train.rnn <- function (model, train.data, eval.data,
-                       num.round, update.period,
-                       init.states.name,
-                       optimizer='sgd', ctx=mx.ctx.default(), 
-                       epoch.end.callback,
-                       batch.end.callback,
-                       verbose=TRUE,
-                       ...) {
-    m <- model
-    
-    model <- list(symbol=model$symbol, 
arg.params=model$rnn.exec$ref.arg.arrays,
-                  aux.params=model$rnn.exec$ref.aux.arrays)
-    
-    seq.len <- m$seq.len
-    batch.size <- m$batch.size
-    num.rnn.layer <- m$num.rnn.layer
-    num.hidden <- m$num.hidden
-
-    opt <- mx.opt.create(optimizer, rescale.grad=(1/batch.size), ...)
-
-    updater <- mx.opt.get.updater(opt, m$rnn.exec$ref.arg.arrays)
-    epoch.counter <- 0
-    log.period <- max(as.integer(1000 / seq.len), 1)
-    last.perp <- 10000000.0
-
-    for (iteration in 1:num.round) {
-        nbatch <- 0
-        train.nll <- 0
-        # reset states
-        init.states <- list()
-        for (name in init.states.name) {
-            init.states[[name]] <- m$rnn.exec$ref.arg.arrays[[name]]*0
-        }
-
-        mx.exec.update.arg.arrays(m$rnn.exec, init.states, match.name=TRUE)
-
-        tic <- Sys.time()
-
-        train.data$reset()
-
-        while (train.data$iter.next()) {
-            # set rnn input
-            rnn.input <- train.data$value()
-            mx.exec.update.arg.arrays(m$rnn.exec, rnn.input, match.name=TRUE)
-
-            mx.exec.forward(m$rnn.exec, is.train=TRUE)
-            seq.label.probs <- 
mx.nd.choose.element.0index(m$rnn.exec$ref.outputs[["sm_output"]], 
get.label(m$rnn.exec$ref.arg.arrays[["label"]], ctx))
-
-            mx.exec.backward(m$rnn.exec)
-            init.states <- list()
-            for (name in init.states.name) {
-                init.states[[name]] <- m$rnn.exec$ref.arg.arrays[[name]]*0
-            }
-
-            mx.exec.update.arg.arrays(m$rnn.exec, init.states, match.name=TRUE)
-            # update epoch counter
-            epoch.counter <- epoch.counter + 1
-            if (epoch.counter %% update.period == 0) {
-                # the gradient of initial c and inital h should be zero
-                init.grad <- list()
-                for (name in init.states.name) {
-                    init.grad[[name]] <- m$rnn.exec$ref.arg.arrays[[name]]*0
-                }
-
-                mx.exec.update.grad.arrays(m$rnn.exec, init.grad, 
match.name=TRUE)
-
-                arg.blocks <- updater(m$rnn.exec$ref.arg.arrays, 
m$rnn.exec$ref.grad.arrays)
-
-                mx.exec.update.arg.arrays(m$rnn.exec, arg.blocks, 
skip.null=TRUE)
-
-                grad.arrays <- list()
-                for (name in names(m$rnn.exec$ref.grad.arrays)) {
-                    if (is.param.name(name))
-                        grad.arrays[[name]] <- 
m$rnn.exec$ref.grad.arrays[[name]]*0
-                }
-                mx.exec.update.grad.arrays(m$rnn.exec, grad.arrays, 
match.name=TRUE)
-
-            }
-
-            train.nll <- train.nll + calc.nll(as.array(seq.label.probs), 
batch.size)
-
-            nbatch <- nbatch + seq.len
-            
-            if (!is.null(batch.end.callback)) {
-              batch.end.callback(iteration, nbatch, environment())
-            }
-            
-            if ((epoch.counter %% log.period) == 0) {
-                message(paste0("Epoch [", epoch.counter,
-                           "] Train: NLL=", train.nll / nbatch,
-                           ", Perp=", exp(train.nll / nbatch)))
-            }
-        }
-        train.data$reset()
-        # end of training loop
-        toc <- Sys.time()
-        message(paste0("Iter [", iteration,
-                   "] Train: Time: ", as.numeric(toc - tic, units="secs"),
-                   " sec, NLL=", train.nll / nbatch,
-                   ", Perp=", exp(train.nll / nbatch)))
-
-        if (!is.null(eval.data)) {
-            val.nll <- 0.0
-            # validation set, reset states
-            init.states <- list()
-            for (name in init.states.name) {
-                init.states[[name]] <- m$rnn.exec$ref.arg.arrays[[name]]*0
-            }
-            mx.exec.update.arg.arrays(m$rnn.exec, init.states, match.name=TRUE)
-
-            eval.data$reset()
-            nbatch <- 0
-            while (eval.data$iter.next()) {
-                # set rnn input
-                rnn.input <- eval.data$value()
-                mx.exec.update.arg.arrays(m$rnn.exec, rnn.input, 
match.name=TRUE)
-                mx.exec.forward(m$rnn.exec, is.train=FALSE)
-                # probability of each label class, used to evaluate nll
-                seq.label.probs <- 
mx.nd.choose.element.0index(m$rnn.exec$ref.outputs[["sm_output"]], 
get.label(m$rnn.exec$ref.arg.arrays[["label"]], ctx))
-                # transfer the states
-                init.states <- list()
-                for (name in init.states.name) {
-                    init.states[[name]] <- m$rnn.exec$ref.arg.arrays[[name]]*0
-                }
-                mx.exec.update.arg.arrays(m$rnn.exec, init.states, 
match.name=TRUE)
-                val.nll <- val.nll + calc.nll(as.array(seq.label.probs), 
batch.size)
-                nbatch <- nbatch + seq.len
-            }
-            eval.data$reset()
-            perp <- exp(val.nll / nbatch)
-            message(paste0("Iter [", iteration,
-                       "] Val: NLL=", val.nll / nbatch,
-                       ", Perp=", exp(val.nll / nbatch)))
-        }
-        # get the model out
-
-
-        epoch_continue <- TRUE
-        if (!is.null(epoch.end.callback)) {
-          epoch_continue <- epoch.end.callback(iteration, 0, environment(), 
verbose = verbose)
-        }
-        
-        if (!epoch_continue) {
-          break
-        }
-    }
-
-    return (m)
-}
-
-# check data and translate data into iterator if data is array/matrix
-check.data <- function(data, batch.size, is.train) {
-    if (!is.null(data) && !is.list(data) && !is.mx.dataiter(data)) {
-        stop("The dataset should be either a mx.io.DataIter or a R list")
-    }
-    if (is.list(data)) {
-        if (is.null(data$data) || is.null(data$label)){
-            stop("Please provide dataset as list(data=R.array, label=R.array)")
-        }
-    data <- mx.model.init.iter.rnn(data$data, data$label, 
batch.size=batch.size, is.train = is.train)
-    }
-    if (!is.null(data) && !data$iter.next()) {
-        data$reset()
-        if (!data$iter.next()) stop("Empty input")
-    }
-    return (data)
-}
diff --git a/R-package/R/viz.graph.R b/R-package/R/viz.graph.R
index 7d0365b..aef90ad 100644
--- a/R-package/R/viz.graph.R
+++ b/R-package/R/viz.graph.R
@@ -45,6 +45,7 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", 
type="graph", graph.wi
       "MAERegressionOutput"=,
       "SVMOutput"=,
       "LogisticRegressionOutput"=,
+      "MakeLoss"=,
       "SoftmaxOutput" = "#b3de69",
       "#fccde5" # default value
     )
@@ -145,9 +146,6 @@ graph.viz <- function(symbol, shape=NULL, direction="TD", 
type="graph", graph.wi
   } else {
     graph_render<- render_graph(graph = graph, output = "graph", width = 
graph.width.px, height = graph.height.px)
   }
-
-  # graph <-visNetwork(nodes = nodes_df, edges = edges_df, main = graph.title) 
%>%
-  #   visHierarchicalLayout(direction = "UD", sortMethod = "directed")
   
   return(graph_render)
 }
diff --git a/R-package/tests/testthat/test_lstm.R 
b/R-package/tests/testthat/test_lstm.R
deleted file mode 100644
index 4a5cdbe..0000000
--- a/R-package/tests/testthat/test_lstm.R
+++ /dev/null
@@ -1,57 +0,0 @@
# NOTE(review): this test file is deleted by this commit (every line carries
# the diff's "-" prefix); these notes describe the removed test.
# It fits an mx.lstm model (num.lstm.layer = 2) on a toy dataset and checks
# that the NLL values in the captured training output never increase.
-require(mxnet)
-
# Use the GPU when R_GPU_ENABLE parses as 1. `&` evaluates both sides, but an
# unset variable gives FALSE & NA, which is FALSE, so the branch is skipped.
-if (Sys.getenv("R_GPU_ENABLE") != "" & as.integer(Sys.getenv("R_GPU_ENABLE")) 
== 1) {
-  mx.ctx.default(new = mx.gpu())
-  message("Using GPU for testing.")
-}
-
-context("lstm models")
-
# Extract the number following "NLL=" from one line of captured output.
# sub() returns a non-matching line unchanged, so such a line turns into NA.
-get.nll <- function(s) {
-    pat <- ".*\\NLL=(.+), Perp=.*"
-    nll <- sub(pat, "\\1", s)
-    return (as.numeric(nll))
-} 
-
-test_that("training error decreasing", {
-
-    # Set basic network parameters.
-    batch.size = 2
-    seq.len = 2
-    num.hidden = 1
-    num.embed = 2
-    num.lstm.layer = 2
-    num.round = 5
-    learning.rate= 0.1
-    wd=0.00001
-    clip_gradient=1
-    update.period = 1
-    vocab=17
-
-    X.train <- list(data=array(1:16, dim=c(2,8)), label=array(2:17, 
dim=c(2,8)))
-
# Capture everything the training call prints, one element per output line.
-    s <- capture.output(model <- mx.lstm( X.train, 
-                                          ctx=mx.ctx.default(),
-                                          num.round=num.round, 
-                                          update.period=update.period,
-                                          num.lstm.layer=num.lstm.layer, 
-                                          seq.len=seq.len,
-                                          num.hidden=num.hidden, 
-                                          num.embed=num.embed, 
-                                          num.label=vocab,
-                                          batch.size=batch.size, 
-                                          input.size=vocab,
-                                          initializer=mx.init.uniform(0.01), 
-                                          learning.rate=learning.rate,
-                                          wd=wd,
-                                          clip_gradient=clip_gradient))
-
# Each captured line must report an NLL no larger than the previous one; a
# line without "NLL=" yields NA, which makes expect_true fail.
-    prev.nll <- 10000000.0
-    for (r in s) {
-        nll <- get.nll(r)
-        expect_true(prev.nll >= nll)
-        prev.nll <- nll
-
-    }
-
-})
\ No newline at end of file
diff --git a/example/rnn/bucket_R/data_preprocessing_seq_to_one.R 
b/example/rnn/bucket_R/data_preprocessing_seq_to_one.R
new file mode 100644
index 0000000..11c0a0c
--- /dev/null
+++ b/example/rnn/bucket_R/data_preprocessing_seq_to_one.R
@@ -0,0 +1,176 @@
# Download the IMDB sentiment dataset and unpack it under data/, which is
# where the review files are read from below (data/aclImdb/...).
imdb_url <- "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
imdb_archive <- "data/aclImdb_v1.tar.gz"

# download.file() cannot create missing directories for its destfile.
dir.create("data", showWarnings = FALSE)

if (!file.exists(imdb_archive)) {
  # The archive is binary; mode = "wb" matters on Windows.
  download.file(imdb_url, imdb_archive, mode = "wb")
}

# Extract whenever the unpacked folder is missing, so that an archive
# downloaded earlier is still unpacked. untar() defaults to the working
# directory, but the reviews are read from data/aclImdb.
if (!dir.exists("data/aclImdb")) {
  untar(imdb_archive, exdir = "data")
}
+
# Install whichever required packages are missing, then attach them all.
# requireNamespace() is the recommended way to test whether a package is
# usable (installed.packages() is slow and not meant for this check), and
# library() stops with an error if a package still cannot be loaded,
# whereas require() only returns FALSE and lets the script fail later.
list.of.packages <- c("readr", "dplyr", "stringr", "stringi")
is_installed <- vapply(list.of.packages, requireNamespace, logical(1),
                       quietly = TRUE)
new.packages <- list.of.packages[!is_installed]
if (length(new.packages) > 0) {
  install.packages(new.packages)
}

library(readr)
library(dplyr)
library(stringr)
library(stringi)
+
# Paths of the review files of each split; each file is read as one review.
negative_train_list <- list.files("data/aclImdb/train/neg/", full.names = TRUE)
positive_train_list <- list.files("data/aclImdb/train/pos/", full.names = TRUE)

negative_test_list <- list.files("data/aclImdb/test/neg/", full.names = TRUE)
positive_test_list <- list.files("data/aclImdb/test/pos/", full.names = TRUE)
+
# Read each file of `file_list` into a single string.
#
# Args:
#   file_list: character vector of file paths.
#   reader: function mapping one path to a single string; defaults to
#     read_file from readr (attached earlier in this script).
#
# Returns:
#   A character vector with one element per file, named by its path.
#   vapply() (unlike sapply()) still returns a character vector for an
#   empty `file_list` and errors if `reader` does not yield one string.
file_import <- function(file_list, reader = read_file) {
  vapply(file_list, reader, FUN.VALUE = character(1))
}
+
# Read the raw reviews split by split. Within each split the negative
# reviews come first and the positive ones second; the labels built further
# down depend on this ordering.
negative_train_raw <- file_import(negative_train_list)
positive_train_raw <- file_import(positive_train_list)
train_raw <- c(negative_train_raw, positive_train_raw)

negative_test_raw <- file_import(negative_test_list)
positive_test_raw <- file_import(positive_test_list)
test_raw <- c(negative_test_raw, positive_test_raw)
+
# Pre-process a corpus (a character vector with one document per element):
# normalise the text, split each document into words and keep only the
# words of a vocabulary.
#
# Args:
#   corpus: character vector of raw documents.
#   count_threshold: minimum number of occurrences for a word to enter a
#     newly built vocabulary; ignored when `dic` is supplied.
#   dic: optional dictionary, i.e. the `dic` element returned by a previous
#     call. When NULL a vocabulary is built from `corpus`; otherwise its
#     non-padding words are reused.
#
# Returns:
#   A list with elements
#     word_vec_list: one character vector per document, holding only
#       vocabulary words in their original order;
#     dic: named numeric vector word -> index, with the padding token
#       "\u00a4" first (index 0) and vocabulary words numbered 1..n;
#     rev_dic: the inverse mapping, index (as a name) -> word.
text_pre_process <- function(corpus, count_threshold = 10, dic = NULL) {
  raw_vec <- stri_enc_toascii(str = corpus)

  ### Normalise the text
  raw_vec <- str_replace_all(string = raw_vec, pattern = "[^[:print:]]",
                             replacement = "")
  raw_vec <- str_to_lower(string = raw_vec)
  raw_vec <- str_replace_all(string = raw_vec, pattern = "_",
                             replacement = " ")
  # Drop standalone "br" tokens (presumably remnants of HTML <br /> tags).
  raw_vec <- str_replace_all(string = raw_vec, pattern = "\\bbr\\b",
                             replacement = "")
  raw_vec <- str_replace_all(string = raw_vec, pattern = "\\s+",
                             replacement = " ")
  raw_vec <- str_trim(string = raw_vec)

  ### Split each document into a vector of words (numbers are kept)
  word_vec_list <- stri_split_boundaries(raw_vec, type = "word",
                                         skip_word_none = TRUE,
                                         skip_word_number = FALSE,
                                         simplify = FALSE)

  ### Vocabulary: build it from the corpus or reuse the one in `dic`
  if (is.null(dic)) {
    word_counts <- sort(table(unlist(word_vec_list)), decreasing = TRUE)
    # Keep every word seen at least `count_threshold` times. A cut-off
    # based on which.max(word_counts < count_threshold) returns 1 when no
    # word (or every word) is below the threshold, which would keep only
    # the single most frequent word.
    word_keep <- names(word_counts)[word_counts >= count_threshold]
    stopwords <- c(letters, "an", "the", "br")
    word_keep <- setdiff(word_keep, stopwords)
  } else {
    word_keep <- names(dic)[dic != 0]
  }

  ### Keep only vocabulary words in each document
  word_vec_list <- lapply(word_vec_list,
                          function(words) words[words %in% word_keep])

  ### Dictionary: padding token at index 0, then the vocabulary in order
  dic <- c(setNames(0, "\u00a4"), setNames(seq_along(word_keep), word_keep))

  ### Reverse dictionary: index (as a name) -> word
  rev_dic <- setNames(names(dic), dic)

  return(list(word_vec_list = word_vec_list, dic = dic, rev_dic = rev_dic))
}
+
################################################################
# Group tokenised sequences into length buckets and encode each bucket as a
# matrix of dictionary indices.
#
# Args:
#   word_vec_list: list of character vectors, one per sequence (e.g. the
#     `word_vec_list` element returned by text_pre_process()).
#   labels: one label per sequence, in the same order as `word_vec_list`.
#   dic: named vector word -> index; its first element names the padding
#     token.
#   seq_len: strictly increasing bucket lengths. Sequences are truncated to
#     max(seq_len); each then goes to the smallest bucket that can hold it
#     and is padded up to that bucket's length.
#   right_pad: TRUE pads at the end of each sequence. FALSE reverses the
#     padded sequence, which moves the padding to the front but also
#     reverses the word order.
#
# Returns:
#   A list with elements
#     buckets: named by bucket length; each is list(data, label) where
#       `data` has one row per sequence and seq_len[k] columns;
#     dic, rev_dic: the dictionary and its inverse (index as name -> word).
make_bucket_data <- function(word_vec_list, labels, dic, seq_len = c(225),
                             right_pad = TRUE) {
  # cut() sorts its breaks, so unsorted bucket lengths would silently pad
  # sequences to the wrong length; labels must align with the sequences.
  stopifnot(length(labels) == length(word_vec_list),
            !is.unsorted(seq_len, strictly = TRUE))
  pad_token <- names(dic)[1]

  ### Truncate sequences to the largest bucket length
  word_vec_list <- lapply(word_vec_list, head, n = max(seq_len))

  # Bucket k holds lengths in (seq_len[k - 1], seq_len[k]]; include.lowest
  # puts empty sequences in bucket 1. The trailing Inf break is never hit
  # because of the truncation above.
  bucketID <- cut(lengths(word_vec_list), breaks = c(0, seq_len, Inf),
                  include.lowest = TRUE, labels = FALSE)

  ### Pad every sequence to its bucket length with the padding token
  word_vec_list_pad <- lapply(seq_along(word_vec_list), function(i) {
    words <- word_vec_list[[i]]
    length(words) <- seq_len[bucketID[i]]
    words[is.na(words)] <- pad_token
    if (!right_pad) {
      words <- rev(words)
    }
    words
  })

  ### Encode each bucket: flatten its padded sequences, map words to their
  ### indices and reshape so that every sequence becomes one row
  buckets <- lapply(seq_along(seq_len), function(k) {
    in_bucket <- bucketID == k
    ids <- dic[unlist(word_vec_list_pad[in_bucket])]
    features <- t(array(ids, dim = c(seq_len[k], length(ids) / seq_len[k])))
    list(data = features, label = labels[in_bucket])
  })
  names(buckets) <- as.character(seq_len)

  ### Reverse dictionary: index (as a name) -> word
  rev_dic <- setNames(names(dic), dic)

  return(list(buckets = buckets, dic = dic, rev_dic = rev_dic))
}
+
+
# Build the vocabulary on the training reviews and reuse it for the test
# reviews so that both splits share the same word indices.
corpus_preprocessed_train <- text_pre_process(corpus = train_raw,
                                              count_threshold = 10,
                                              dic = NULL)
corpus_preprocessed_test <- text_pre_process(
  corpus = test_raw,
  dic = corpus_preprocessed_train$dic
)

# Distribution of the vocabulary-filtered review lengths; useful when
# choosing the bucket lengths below.
seq_length_dist <- lengths(corpus_preprocessed_train$word_vec_list)
quantile(seq_length_dist, 0:20/20)

# train_raw / test_raw hold the negative reviews first and the positive
# ones second, so label 0 = negative and 1 = positive. Take the class sizes
# from the data rather than assuming exactly 12500 reviews per class.
train_labels <- rep(0:1, times = c(length(negative_train_raw),
                                   length(positive_train_raw)))
test_labels <- rep(0:1, times = c(length(negative_test_raw),
                                  length(positive_test_raw)))

bucket_lengths <- c(100, 150, 250, 400, 600)

# Save bucketed corpus
corpus_bucketed_train <- make_bucket_data(
  word_vec_list = corpus_preprocessed_train$word_vec_list,
  labels = train_labels,
  dic = corpus_preprocessed_train$dic,
  seq_len = bucket_lengths,
  right_pad = TRUE
)

corpus_bucketed_test <- make_bucket_data(
  word_vec_list = corpus_preprocessed_test$word_vec_list,
  labels = test_labels,
  dic = corpus_preprocessed_test$dic,
  seq_len = bucket_lengths,
  right_pad = TRUE
)

saveRDS(corpus_bucketed_train, file = "data/corpus_bucketed_train.rds")
saveRDS(corpus_bucketed_test, file = "data/corpus_bucketed_test.rds")


# Save non bucketed corpus: a single bucket as long as the longest one above
corpus_single_train <- make_bucket_data(
  word_vec_list = corpus_preprocessed_train$word_vec_list,
  labels = train_labels,
  dic = corpus_preprocessed_train$dic,
  seq_len = max(bucket_lengths),
  right_pad = TRUE
)

corpus_single_test <- make_bucket_data(
  word_vec_list = corpus_preprocessed_test$word_vec_list,
  labels = test_labels,
  dic = corpus_preprocessed_test$dic,
  seq_len = max(bucket_lengths),
  right_pad = TRUE
)

saveRDS(corpus_single_train, file = "data/corpus_single_train.rds")
saveRDS(corpus_single_test, file = "data/corpus_single_test.rds")

-- 
To stop receiving notification emails like this one, please contact
['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].

Reply via email to