Author: Matti Picus <[email protected]>
Branch: numpy_broadcast_nd
Changeset: r84070:96c7090938b3
Date: 2016-04-30 23:12 +0300
http://bitbucket.org/pypy/pypy/changeset/96c7090938b3/

Log:    try to treat W_Broadcast as true W_NumpyObjects

diff --git a/pypy/module/micronumpy/base.py b/pypy/module/micronumpy/base.py
--- a/pypy/module/micronumpy/base.py
+++ b/pypy/module/micronumpy/base.py
@@ -169,6 +169,6 @@
 
 def convert_to_array(space, w_obj):
     from pypy.module.micronumpy.ctors import array
-    if isinstance(w_obj, W_NDimArray):
+    if isinstance(w_obj, W_NumpyObject) and not w_obj.is_scalar():  # NOTE(review): lets W_Broadcast through; unlike the old W_NDimArray check, objects whose is_scalar() is True (possibly 0-d arrays) now go through array() -- confirm intended
         return w_obj
     return array(space, w_obj)
diff --git a/pypy/module/micronumpy/broadcast.py b/pypy/module/micronumpy/broadcast.py
--- a/pypy/module/micronumpy/broadcast.py
+++ b/pypy/module/micronumpy/broadcast.py
@@ -42,6 +42,21 @@
 
         self.done = False
 
+    def get_shape(self):  # read via the W_NumpyObject interface, e.g. by shape_agreement() in strides.py
+        return self.shape
+
+    def get_order(self):  # NOTE(review): assumes self.order is set in __init__ (not visible in this hunk)
+        return self.order
+
+    def get_dtype(self):
+        return self.seq[0].get_dtype() #XXX Fixme: uses only the first operand's dtype rather than one derived from every array in self.seq
+
+    def get_size(self):
+        return 0  #XXX Fixme: placeholder; presumably should be the number of elements in the broadcast shape (self.shape) -- confirm
+
+    def create_iter(self, shape=None, backward_broadcast=False):
+        return self, self.list_iter_state # XXX Fixme: shape and backward_broadcast are ignored; NOTE(review): list_iter_state is not set in the visible code -- confirm it exists
+
     def descr_iter(self, space):
         return space.wrap(self)
 
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -1,7 +1,7 @@
 from pypy.interpreter.error import OperationError, oefmt
 from rpython.rlib import jit
 from pypy.module.micronumpy import support, constants as NPY
-from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.base import W_NDimArray, W_NumpyObject
 
 
 # structures to describe slicing
@@ -218,7 +218,7 @@
 def shape_agreement(space, shape1, w_arr2, broadcast_down=True):
     if w_arr2 is None:
         return shape1
-    assert isinstance(w_arr2, W_NDimArray)
+    assert isinstance(w_arr2, W_NumpyObject)
     shape2 = w_arr2.get_shape()
     ret = _shape_agreement(shape1, shape2)
     if len(ret) < max(len(shape1), len(shape2)):
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to