On Wed, 2021-09-29 at 15:09 -0400, Aaron Watters wrote:
> Hi folks!
> 
> The np.choose function raises a ValueError if called with more than
> 31 choices.
> This PR adds an alternate implementation np.extended_choose (which
> uses the base implementation) that supports any number of choices.
> 
> https://github.com/numpy/numpy/pull/20001
> 
> FYI, I needed this functionality for a mouse embryo microscopy tool
> I'm building.
> I'm attempting to contribute it because I thought it might be
> generally useful.


Thanks for the effort of upstreaming your code!  My inclination is
against adding it, though.  The limitation of `choose` to 32 arguments
is unfortunate, but adding a new function as a workaround does not seem
great, either.
Maybe it would be possible to fix `choose` instead?  Unfortunately, it
seems likely that the current `choose` code is great for few choices
but bad for many, so that might require switching between different
strategies.

Importantly:  If your data (choices) is an array and not a sequence of
arrays, you should use `np.take_along_axis` instead, which is far
superior!
For small to mid-sized arrays, it may even be fastest to use it with
`np.asarray(choices)`, because it avoids many overheads.
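
As a minimal sketch of the `np.take_along_axis` route, assuming the
choices are stacked along the first axis of a single array and `a` is
broadcastable to the remaining shape (`choose_from_array` is a
hypothetical helper name for illustration, not part of NumPy):

```python
import numpy as np


def choose_from_array(a, choices):
    """Pick choices[a[i, ...], i, ...] from one stacked array of choices."""
    choices = np.asarray(choices)
    # Assumption: `a` broadcasts to the shape of a single choice.
    a = np.broadcast_to(np.asarray(a), choices.shape[1:])
    # The index array needs the same number of dimensions as `choices`,
    # so add a length-1 leading axis and drop it again afterwards.
    return np.take_along_axis(choices, a[np.newaxis, ...], axis=0)[0]


# 40 choices, each of shape (5,), stacked in one array.
choices = np.arange(200).reshape(40, 5)
a = np.array([39, 0, 17, 32, 5])
picked = choose_from_array(a, choices)  # array([195, 1, 87, 163, 29])
```

This never unpacks the choices into separate arrays, so the number of
choices is limited only by the length of the first axis.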


Since I am not happy with the idea of stretching the way `choose`
currently works to many choices, I cooked up the approach below.  My
expectation is that it should be much faster for many choices, at least
for larger arrays.
It moves the work of figuring out which element to pick from which
choice into a (fairly involved) pre-processing step, to make the final
assignment more streamlined.
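
To illustrate just the pre-processing idea on a tiny 1-D index array
(a sketch only; the names `boundaries`, `starts`, `stops` and `groups`
are made up for this example): sorting the indices makes all positions
that pick the same choice contiguous, so each choice can be handled
with one slice.

```python
import numpy as np

a = np.array([2, 0, 1, 0, 2, 2])

# A stable sort keeps positions with equal values in their original order.
sorter = np.argsort(a, kind="stable")   # [1, 3, 2, 0, 4, 5]
which = a[sorter]                       # [0, 0, 1, 2, 2, 2]

# Positions in `which` where a new run of equal values starts.
boundaries = np.flatnonzero(which[1:] != which[:-1]) + 1   # [2, 3]
starts = np.concatenate(([0], boundaries))
stops = np.concatenate((boundaries, [len(which)]))

# Map each choice number to the original positions that select it.
groups = {int(which[s]): sorter[s:e].tolist() for s, e in zip(starts, stops)}
# {0: [1, 3], 1: [2], 2: [0, 4, 5]}
```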

Cheers,

Sebastian


```
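from itertools import chain

import numpy as np


def choose(a, choices):
    # Make sure we work with the correct result shape.
    # (this is not great if `a` ends up being broadcast)
    a_bc, *choices = np.broadcast_arrays(a, *choices)

    # Sort the flattened index array so that all positions picking the
    # same choice end up next to each other.
    a = a_bc.ravel()
    sorter = np.argsort(a, axis=None)
    which = a[sorter]

    # Per-axis coordinates of every element, in the same sorted order.
    # indexing="ij" keeps the grids in the shape (and C order) of `a_bc`;
    # the default "xy" would swap the first two axes.
    indices = np.meshgrid(*[np.arange(s) for s in a_bc.shape], indexing="ij")
    indices = [i.ravel()[sorter] for i in indices]

    out_dtype = np.result_type(*choices)
    result = np.empty(choices[0].shape, dtype=out_dtype)

    # `ends + 1` are the exclusive end positions of each run of equal
    # values in `which`; the last run ends at len(which).
    ends = np.flatnonzero(which[1:] != which[:-1])
    start = 0
    for end in chain(ends + 1, [len(which)]):
        choice = choices[which[start]]
        ind = tuple(i[start:end] for i in indices)
        result[ind] = choice[ind]

        start = end

    return result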
```


> 
> All comments, complaints, or suggestions or code reviews appreciated.
> 
>    thanks! -- Aaron Watters
> _______________________________________________
> NumPy-Discussion mailing list -- numpy-discussion@python.org
> To unsubscribe send an email to numpy-discussion-le...@python.org
> https://mail.python.org/mailman3/lists/numpy-discussion.python.org/
> Member address: sebast...@sipsolutions.net
