NamelessLoop edited a comment on issue #19778:
URL:
https://github.com/apache/incubator-mxnet/issues/19778#issuecomment-765465584
The `python_code.pyx` file:
```
import mxnet
cimport numpy
cimport mxnet???
def MxNetTrain(net, model_type, trainer, loss_function, train_input,
train_output, val_input, val_output, test_input,
test_output, epochs, args_save):
# In case of attention training un-pack the Tuple objects
if (model_type == 'seq2seq - attention') or model_type == 'seq2seq - transformer':
encoder_rnn, attn_decoder = net[0], net[1]
trainer = trainer[1]
# init validation loss as infinity
cdef double best_val
best_val = float("Inf")
# initializing arrays to save parameters through epochs
cdef list learning_rates = []
cdef list validation_losses = []
cdef list test_losses = []
cdef list training_losses = []
with open(args_save, 'w') as fp: # condition to avoid Jenkins/file_directory error
pass
cdef int train_bs = train_input.shape[0]
cdef int val_bs = val_input.shape[0]
cdef int test_bs = test_input.shape[0]
# creating nd.arrayIter (training, validation and testing)
train_iter = mx.io.NDArrayIter(train_input, train_output,
batch_size=train_bs, shuffle=True)
val_iter = mx.io.NDArrayIter(val_input, val_output, batch_size=val_bs,
shuffle=True)
test_iter = mx.io.NDArrayIter(test_input, test_output,
batch_size=test_bs, shuffle=True)
cdef double total_loss = 0.0
cdef double ntotal_training = 0.0
cdef ??? encoder_input
cdef ??? decoder_target
# for epoch in tqdm_notebook(range(epochs), desc='epochs'):
iepoch = int()
for epoch in range(epochs):
train_iter.reset()
val_iter.reset()
test_iter.reset()
for trn_batch in train_iter:
if (model_type == 'seq2seq - attention') or (model_type ==
'seq2seq - transformer'):
encoder_input = trn_batch.data[0].as_in_context(mx.cpu())
decoder_target = trn_batch.label[0].as_in_context(mx.cpu())
```
Then the `setup.py`:
```
from distutils.core import setup
import numpy
from Cython.Build import cythonize
setup(
ext_modules= cythonize("MxRnnBoosted.pyx"),
include_dirs=[numpy.get_include(), mxnet_get_include_path_]
)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]