GSanchis commented on issue #8669: module.Module and CSVIter
URL: https://github.com/apache/incubator-mxnet/issues/8669#issuecomment-350847038

So, after diving deep into this, and into some of MXNet's Python libraries, I managed to write a first version of an iterator that seems to do the job:

```python
import csv

import mxnet as mx
import numpy
from mxnet.io import DataDesc


class myCSVIter(mx.io.DataIter):
    """Reads 'user,item,label' rows from a CSV file in batches of batch_size."""

    def __init__(self, data_names, data_shapes, label_names, label_shapes,
                 csvfile, delimiter=',', batch_size=100):
        # data_names, data_shapes, label_names and label_shapes are currently
        # unused; the names and shapes below are hard-coded.
        self.delimiter = delimiter
        self.batch_size = batch_size
        self._provide_data = [DataDesc('user', (batch_size,), int),
                              DataDesc('item', (batch_size,), int)]
        self._provide_label = [DataDesc('softmax_label', (batch_size,), numpy.float32)]
        self.file = open(csvfile, 'r')
        self.csvreader = csv.reader(self.file, delimiter=delimiter)

    def __iter__(self):
        return self

    def reset(self):
        # Rewind the file and start reading from the first row again
        self.file.seek(0)
        self.csvreader = csv.reader(self.file, delimiter=self.delimiter)

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        l = []  # column 0: user ids
        c = []  # column 1: item ids
        v = []  # column 2: labels
        try:
            for i in range(self.batch_size):
                row = [int(a) for a in next(self.csvreader)]
                l.append(row[0])
                c.append(row[1])
                v.append(row[2])
        except StopIteration:
            # End of file: a trailing partial batch is dropped for now
            raise StopIteration
        data = [mx.nd.array(l), mx.nd.array(c)]
        label = [mx.nd.array(v)]
        return mx.io.DataBatch(data=data, label=label)
```

I'm still getting this straight, though, especially because it seems to work properly on a CPU, but I got an error when I tried the code on a GPU. I don't have access to the error right now, but I'll continue tomorrow. I also still have to handle the case where the number of rows in the file is not a multiple of `batch_size`.
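One possible way to handle that last case is to fill the final short batch with extra rows and report how many of them are filler through the `pad` argument of `mx.io.DataBatch`, similar in spirit to the `last_batch_handle='pad'` option of `mx.io.NDArrayIter`. Below is a minimal sketch of just the batching logic, in plain Python so it can be tried without MXNet. The names `padded_batches` and `_split` are hypothetical, and reusing the first rows of the file as filler is an assumption, not necessarily what MXNet's built-in iterators do.

```python
import csv
import io
from typing import Iterable, Iterator, List, Tuple

# (users, items, labels, pad) for one batch
Batch = Tuple[List[int], List[int], List[int], int]


def _split(rows: List[List[int]], pad: int) -> Batch:
    """Hypothetical helper: turn rows of [user, item, label] into columns."""
    users = [r[0] for r in rows]
    items = [r[1] for r in rows]
    labels = [r[2] for r in rows]
    return users, items, labels, pad


def padded_batches(lines: Iterable[str], batch_size: int,
                   delimiter: str = ",") -> Iterator[Batch]:
    """Hypothetical sketch: yield fixed-size batches from 'user,item,label' lines.

    A final short batch is filled by cycling through the first rows of the
    data, and `pad` counts how many entries of that batch are filler.
    """
    head: List[List[int]] = []   # first rows seen, reused as filler
    batch: List[List[int]] = []
    for record in csv.reader(lines, delimiter=delimiter):
        if not record:
            continue  # skip blank lines
        row = [int(a) for a in record]
        if len(head) < batch_size:
            head.append(row)
        batch.append(row)
        if len(batch) == batch_size:
            yield _split(batch, 0)
            batch = []
    if batch:
        # head holds at least len(batch) rows, so cycling through it is safe
        pad = batch_size - len(batch)
        for i in range(pad):
            batch.append(head[i % len(head)])
        yield _split(batch, pad)


if __name__ == "__main__":
    data = io.StringIO("1,10,5\n2,20,3\n3,30,4\n4,40,1\n5,50,2\n")
    for users, items, labels, pad in padded_batches(data, batch_size=2):
        print(users, items, labels, pad)
    # [1, 2] [10, 20] [5, 3] 0
    # [3, 4] [30, 40] [4, 1] 0
    # [5, 1] [50, 10] [2, 5] 1
```

Inside the iterator's `next()`, each tuple could then be wrapped as `mx.io.DataBatch(data=[mx.nd.array(users), mx.nd.array(items)], label=[mx.nd.array(labels)], pad=pad)`, and `reset()` would seek the file back to 0 and create a fresh generator. Keeping every batch the same size avoids changing the shapes declared in `provide_data`, while `pad` tells consumers such as `Module.predict` how many trailing outputs to discard.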