On Wed, 2021-09-29 at 15:09 -0400, Aaron Watters wrote:
> Hi folks!
>
> The np.choose function raises a ValueError if called with more than 31
> choices.
> This PR adds an alternate implementation np.extended_choose (which uses the
> base implementation) that supports any number of choices.
>
> https://github.com/numpy/numpy/pull/20001
>
> FYI, I needed this functionality for a mouse embryo microscopy tool I'm
> building.
> I'm attempting to contribute it because I thought it might be generally
> useful.

Thanks for the effort of upstreaming your code! My inclination is against
adding it, though. The limitation of `choose` to 32 arguments is
unfortunate, but adding a new function as a workaround does not seem
great, either.

Maybe it would be possible to fix `choose` instead? Unfortunately, it
seems likely that the current `choose` code is great for few choices but
bad for many, so that might require switching between different
strategies.

Importantly: If your data (choices) is an array and not a sequence of
arrays, you should use `np.take_along_axis` instead, which is far
superior! For small to mid-sized arrays, it may even be fastest to use it
with `np.asarray(choices)`, because it avoids many overheads.

Not happy with the idea of extending the way `choose` works to many
choices, I cooked up the approach below. My expectation is that it should
be much faster for many choices, at least for larger arrays.

The approach below moves the work of deciding which element to pick from
which choice into a (fairly involved) pre-processing step, to make the
final assignment more streamlined.

Cheers,

Sebastian


```
from itertools import chain

import numpy as np


def choose(a, choices):
    # Make sure we work with the correct result shape.
    # (this is not great if `a` ends up being broadcast)
    a_bc, *choices = np.broadcast_arrays(a, *choices)
    a = a_bc.ravel()

    # Sort the flattened selector so that equal choice numbers form runs.
    sorter = np.argsort(a, axis=None)
    which = a[sorter]

    # Per-axis index arrays in the same sorted order; indexing="ij" keeps
    # them in the axis order of `a_bc` (the default "xy" swaps two axes).
    indices = np.meshgrid(*[np.arange(s) for s in a_bc.shape], indexing="ij")
    indices = [i.ravel()[sorter] for i in indices]

    out_dtype = np.result_type(*choices)
    result = np.empty(choices[0].shape, dtype=out_dtype)

    # Last position of each run of equal values in `which`.
    ends = np.flatnonzero(which[1:] != which[:-1])
    start = 0
    for end in chain(ends, [len(which) - 1]):
        end += 1  # make the run end exclusive
        choice = choices[which[start]]
        ind = tuple(i[start:end] for i in indices)
        # One fancy-indexed assignment per distinct choice.
        result[ind] = choice[ind]
        start = end

    return result
```

> All comments, complaints, or suggestions or code reviews appreciated.
>
> thanks!
> -- Aaron Watters
> _______________________________________________
> NumPy-Discussion mailing list -- numpy-discussion@python.org
> To unsubscribe send an email to numpy-discussion-le...@python.org
> https://mail.python.org/mailman3/lists/numpy-discussion.python.org/
> Member address: sebast...@sipsolutions.net
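
For reference, a minimal sketch of the `np.take_along_axis` suggestion above, assuming the choices are already stacked into a single array along axis 0. The data and the names `choices`, `a`, `picked` and `expected` are made up for illustration and are not taken from the PR:

```python
import numpy as np

# Hypothetical data: 40 choices (more than np.choose accepts), each (3, 4).
rng = np.random.default_rng(0)
choices = rng.integers(0, 100, size=(40, 3, 4))
a = rng.integers(0, 40, size=(3, 4))  # which choice to pick for each element

# Give `a` a leading length-1 axis so it lines up with the choices axis,
# pick along that axis, then drop the length-1 axis again.
picked = np.take_along_axis(choices, a[np.newaxis], axis=0)[0]

# The same selection spelled out element by element, as a cross-check.
expected = np.array(
    [[choices[a[i, j], i, j] for j in range(4)] for i in range(3)]
)
assert np.array_equal(picked, expected)
```

If the choices start out as a sequence of same-shaped arrays, `np.asarray(choices)` would stack them into this layout first, at the cost of one copy.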