dcslin commented on pull request #724:
URL: https://github.com/apache/singa/pull/724#issuecomment-647273232
Hi @joddiy, thank you for the code. I am testing this with an LSTM model, but
it seems that `model.compile` does not pass:
``` python
#!/usr/bin/env python
# coding: utf-8
# In[2]:
import sys
build_path = r'/root/singa-imdb/build/python'
sys.path.append(build_path)
from singa import autograd
from singa import layer
from singa import model
from singa import tensor
from singa import device
from singa import opt
import numpy as np
# Experiment hyper-parameters. They are plain module-level globals: the
# IMDB class reads vocab_size/embed_size, and the driver code reads the rest.
bs, seq_limit = 32, 50        # batch size, tokens per sample
embed_size, hid = 300, 64     # embedding width, hidden Linear width
max_epoch = 20                # not referenced by the code shown here
vocab_size = 100              # token ids are drawn from [0, vocab_size)
# In[3]:
class IMDB(model.Model):
    """Embedding -> flatten -> Linear -> ReLU -> Linear(2) classifier.

    The embedding table size comes from the module-level ``vocab_size`` and
    ``embed_size`` globals.

    Args:
        hidden_size: width of the hidden Linear layer.
        seq_length: accepted for interface compatibility; this model does
            not use it.
    """

    def __init__(self, hidden_size, seq_length):
        super().__init__()
        self.em = layer.Embedding(vocab_size, embed_size)
        # This width was hard-coded to 64 before, which silently ignored
        # hidden_size. The driver passes hid == 64, so the repro is unchanged.
        self.l1 = layer.Linear(hidden_size)
        self.l2 = layer.Linear(2)

    def forward(self, x):
        """Return the 2-way logits for a batch of token ids ``x``."""
        y = self.em(x)
        # Flatten everything after the batch dimension before the dense
        # layers. Presumably em(x) is (batch, seq, embed_size) -- confirm.
        y = autograd.reshape(y, (y.shape[0], -1))
        y = self.l1(y)
        y = autograd.relu(y)
        y = self.l2(y)
        return y

    def train_one_batch(self, x, y):
        """Run forward, compute softmax cross-entropy and apply the optimizer.

        Returns:
            ``(out, loss)``: the logits and the softmax cross-entropy loss.
        """
        out = self.forward(x)
        loss = autograd.softmax_cross_entropy(out, y)
        self.optimizer(loss)
        return out, loss

    def set_opt(self, optimizer):
        """Attach the optimizer that train_one_batch calls with the loss."""
        self.optimizer = optimizer
# In[ ]:
dev = device.create_cuda_gpu_on(7)

# Random token ids in [0, vocab_size), shape (bs, seq_limit).
# NOTE(review): np.random.randint yields the platform default int dtype
# (typically int64); confirm tensor.from_numpy / Embedding accept it.
x = np.random.randint(0, vocab_size, (bs, seq_limit))
tx = tensor.from_numpy(x)
tx.to_device(dev)

# One-hot float32 targets of shape (bs, 2) for softmax_cross_entropy.
# Gaussian noise, used previously, is not a valid class-target encoding.
y = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, bs)]
ty = tensor.from_numpy(y)
ty.to_device(dev)

m = IMDB(hid, seq_limit)
m.set_opt(opt.SGD())
m.compile([tx], is_train=True, use_graph=False, sequential=False)
# In[1]:
"""
WARNING: Logging before InitGoogleLogging() is written to STDERR
F0622 04:00:20.654505 388 common.cc:34] Check failed: initialized_ Must
initialize data before reading it
*** Check failure stack trace: ***
"""
# In[3]:
# out, loss = m(tx, ty)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org