Author: Alex Gaynor <alex.gay...@gmail.com> Branch: numpy-dtype-refactor Changeset: r49553:f65a5c270040 Date: 2011-11-19 14:44 -0500 http://bitbucket.org/pypy/pypy/changeset/f65a5c270040/
Log: recognize scalars in array's constructor when guessing types diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -2,7 +2,7 @@ from pypy.interpreter.error import OperationError, operationerrfmt from pypy.interpreter.gateway import interp2app from pypy.interpreter.typedef import TypeDef, GetSetProperty, interp_attrproperty -from pypy.module.micronumpy import interp_dtype, signature, types +from pypy.module.micronumpy import interp_boxes, interp_dtype, signature, types from pypy.rlib import jit from pypy.rlib.rarithmetic import LONG_BIT from pypy.tool.sourcetools import func_with_new_name @@ -248,6 +248,12 @@ long_dtype = interp_dtype.get_dtype_cache(space).w_longdtype int64_dtype = interp_dtype.get_dtype_cache(space).w_int64dtype + if isinstance(w_obj, interp_boxes.W_GenericBox): + dtype = w_obj.get_dtype(space) + if current_guess is None: + return dtype + return find_binop_result_dtype(dtype, current_guess) + if space.isinstance_w(w_obj, space.w_bool): if current_guess is None or current_guess is bool_dtype: return bool_dtype diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -577,7 +577,7 @@ assert b[i] == 2.5 * a[i] def test_dtype_guessing(self): - from numpypy import array, dtype + from numpypy import array, dtype, float64, int8, bool_ assert array([True]).dtype is dtype(bool) assert array([True, False]).dtype is dtype(bool) @@ -587,6 +587,10 @@ assert array([1.2, True]).dtype is dtype(float) assert array([1.2, 5]).dtype is dtype(float) assert array([]).dtype is dtype(float) + assert array([float64(2)]).dtype is dtype(float) + assert array([int8(3)]).dtype is dtype("int8") + assert array([bool_(True)]).dtype is dtype(bool) + assert array([bool_(True), 3.0]).dtype is 
dtype(float) def test_comparison(self): import operator _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit