# enumre.py
# (C) 2010 Gabriel Genellina

"""
Enumerate the language defined by a regular expression. In other words,
generate (lazily) all strings that match a given regular expression.

This module provides only the building blocks (merge, prod, repeat, closure)
to generate the matching strings.  Going from a regular expression to the
adequate sequence of function calls must be done manually or using a parser
(see invRegexInf.py for a pyparsing based example).

This table summarizes the required conversion:

a|b     becomes merge(G(a), G(b))
ab      becomes prod(G(a), G(b))
a?      becomes repeat(G(a), 0, 1)
a*      becomes repeat(G(a), 0, infinite) == closure(G(a))
a+      becomes repeat(G(a), 1, infinite)
a{m,n}  becomes repeat(G(a), m, n)

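where G(x) represents an iterable yielding all the strings matched by x,
sorted by length and then by value; the building blocks rely on that order
and produce their output in the same order.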

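For example, to generate all strings matching the expression "(a|bc)*d",
one should evaluate the following (a one-character string such as 'a' can be
used directly as G(a), since iterating over it yields just that character):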

    prod(
      closure(
        merge(
          'a',
          prod('b','c'))),
      'd')
      
Note that unbounded repeaters like *, +, {m,} may yield an infinite number
of strings.

Example: enumerating the first 21 matches for "(a|bc)*d" (indexes 0 to 20):

>>> import re
>>> g = prod(closure(merge('a', prod('b', 'c'))), 'd')
>>> for i,s in enumerate(g):
...     print(s)
...     assert re.match("(a|bc)*d", s)
...     if i>=20: break
d
ad
aad
bcd
aaad
...
bcabcd
bcbcad
aaaaaad
>>>

Skip the first 10000 matches and show the next 5:

>>> g = prod(closure(merge('a', prod('b', 'c'))), 'd')
>>> for s in islice(g, 10000, 10005):
...     print(s)
...     assert re.match("(a|bc)*d", s)
bcabcaaaabcabcaaaad
bcabcaaaabcabcaabcd
bcabcaaaabcabcabcad
bcabcaaaabcabcbcaad
bcabcaaaabcabcbcbcd
>>>

Based on http://www.cs.utexas.edu/users/misra/Notes.dir/RegExp.pdf

Copyright (c) 2010 Gabriel A. Genellina

"""

import sys
from itertools import tee, chain, islice, groupby
from operator import itemgetter
from mergeinf import imerge
from heapq import merge as hqmerge

# compatibility stuff: each name is probed on the running interpreter, and a
# substitute is bound only when the probe fails.

try:
    xrange
except NameError:
    # Python >= 3.0: the lazy xrange is now simply range
    xrange = range

# Python >= 3.0 dropped sys.maxint; sys.maxsize is the closest substitute.
_int_limit = getattr(sys, 'maxint', None)
if _int_limit is None:
    _int_limit = sys.maxsize
almost_infinite = _int_limit - 1
del _int_limit

try:
    next
except NameError:
    # Python < 2.6 has no next() builtin; call the iterator method instead
    def next(iterator):
        return iterator.next()

try:
    from itertools import imap, izip
except ImportError:
    # Python >= 3.0: the builtin map and zip are already lazy
    imap, izip = map, zip

try:
    reduce
except NameError:
    # Python >= 3.0 moved reduce into functools
    from functools import reduce

# Names exported by `from enumre import *`: the public building blocks.
__all__ = ['merge', 'merge_unsorted', 'prod', 'prod2', 'repeat', 'closure']


def _imerge_bylen(iterables):
    """Merge a (possibly infinite) iterable of sorted iterables of strings.

    This is mergeinf.imerge specialized for strings:
     - the output is ordered first by length, then by value;
     - repeated strings are produced only once.
    Each input must already be sorted in that order; presumably the inputs
    must also arrive ordered by their first element, as mergeinf.imerge
    requires (see the ValueError check in test()).

    >>> g = _imerge_bylen(((x for x in ['','d','bb']), ['a','ff']))
    >>> list(g)
    ['', 'a', 'd', 'bb', 'ff']
    """
    def keyed(source):
        # pair every string with its length, so tuples sort by (len, value)
        lengths, values = tee(source)
        return izip(imap(len, lengths), values)

    merged = imerge(imap(keyed, iterables))
    # strip the length back off ...
    strings = imap(itemgetter(1), merged)
    # ... and collapse each run of equal (hence adjacent) strings to one
    return imap(itemgetter(0), groupby(strings))


def _hqmerge_bylen(*iterables):
    """This variant uses heapq.merge with a known number of
    arguments. Each iterable must come internally sorted, but
    they don't have to be sorted among them.
    """
    # decorate
    decorated = [izip(imap(len, it1), it2) for it1, it2 in (tee(it) for it in iterables)]
    # merge and undecorate
    result = imap(itemgetter(1), hqmerge(*decorated))
    # eliminate duplicates
    return imap(itemgetter(0), groupby(result))


merge_unsorted = _hqmerge_bylen


def merge(*iterables):
    """Merge a fixed number of sorted iterables of strings.

    Positional-argument front end to _imerge_bylen (itself a variant of
    mergeinf.imerge):
     - the output is ordered first by length, then by value;
     - repeated strings are produced only once.

    >>> g = merge((x for x in ['','d','bb']), ['a','ff'])
    >>> list(g)
    ['', 'a', 'd', 'bb', 'ff']
    """
    # the positional arguments already form the (finite) tuple to merge
    return _imerge_bylen(iterables)


def prod2(xs, ys):
    """A generator that computes the product of two (possibly infinite)
    iterables; for each pair (a, b) in the product, a+b is generated
    (string concatenation).  When the inputs are strictly increasing, a
    string obtainable from several pairs is generated only once.
    Provided that input is ordered (by length, then by value), output is
    ordered too, following the same rules as mergeinf.imerge.

    xs and ys may be iterators or any other iterables (lists, strings...);
    iterators passed in are consumed.

    >>> g = prod2((x for x in ['','c','bb']), (x for x in ['d','ff']))
    >>> list(g)
    ['d', 'cd', 'ff', 'bbd', 'cff', 'bbff']
    >>> list(prod2(['', 'c'], ['d']))
    ['d', 'cd']
    """
    # iter() returns iterators unchanged and makes plain iterables usable
    # with next() below.
    xs = iter(xs)
    ys = iter(ys)
    try:
        x = next(xs)
        y = next(ys)
    except StopIteration:
        # one operand is empty, hence so is the product
        return
    ys1, ys = tee(ys)
    xs1, xs = tee(xs)

    # x+y is the minimum of the whole product.  The remaining strings come
    # from three streams, each sorted on its own:
    #   x + (rest of ys),  (rest of xs) + y,  (rest of xs) * (rest of ys).
    # For strictly increasing inputs, the *first* value of the recursive
    # call is greater than the first values of the two generator
    # expressions.  But those two have no fixed order between them, so
    # mergeinf.imerge (which needs inputs ordered by their first value)
    # cannot be used; heapq.merge is used instead, via _hqmerge_bylen.
    # (This could be reworked.)

    # It's important to yield *before* the recursive call: the merge needs
    # the first value of every input, so each recursion level must produce
    # its first value without recursing further (else infinite inputs
    # would recurse without bound).
    yield x+y
    for z in _hqmerge_bylen(
          ((x + y1) for y1 in ys1),
          ((x1 + y) for x1 in xs1),
          prod2(xs, ys)
      ):
        yield z


def prod(*args):
    """Same as prod2 but takes any number of arguments,
    and they may be any kind of iterable. Left-associative.

    >>> g = prod(['d','ff'], ['', 'c','bb'])
    >>> list(g)
    ['d', 'dc', 'ff', 'dbb', 'ffc', 'ffbb']
    >>> g = prod(['d','ff'], ['', 'c','bb'], ['x','y'])
    >>> list(g)[-1]
    'ffbby'
    """
    return reduce(prod2, (iter(xs) for xs in args))


def _powers(xs, minr=1, maxr=-1):
    """Generate successive iterators, each one representing
    xs**r for r ranging between minr and maxr inclusive.
    xs may be any iterable.
    xs**2 is prod(xs, xs), xs**3 is prod(xs, xs, xs), and so on.

    >>> x3 = list(_powers(['a','b'], 3, 3))
    >>> len(x3)==1
    True
    >>> list(x3[0]) == list(prod(['a','b'], ['a','b'], ['a','b']))
    True
    """

    if maxr < minr: maxr = minr
    include_empty_str = (minr == 0)
    # an empty iterable? xs**n is empty too
    xs, xs2 = tee(xs)
    try:
        x = next(xs)
    except StopIteration:
        return
    if x == '':
        # it the iterable contains the empty string, it will appear in any power
        # otherwise, '' appears only if xs**0 was requested
        result, base = tee(xs)  # drop ''
        include_empty_str = True
    else:
        result, base = tee(xs2)
    del xs, xs2

    if include_empty_str:
        yield iter(('',))

    i = 1
    while True:
        # invariant: result == xs**i
        if minr <= i <= maxr:
            result, result2 = tee(result)
            yield result2
        if i >= maxr:
            break
        base, base2 = tee(base)
        result = prod2(result, base2)
        i += 1


def repeat(xs, minr=1, maxr=-1):
    """Lazily generate every string in xs**r for minr <= r <= maxr,
    ordered by length and then by value, without repetitions.

    xs may be any iterable sorted by length and then by value;
    xs**2 is prod(xs, xs), xs**3 is prod(xs, xs, xs), and so on.  When
    maxr is smaller than minr (as with the default of -1) it is taken to
    be minr, i.e. the result is exactly xs**minr.

    >>> list(repeat(['a','b'], 2))
    ['aa', 'ab', 'ba', 'bb']
    >>> list(repeat(['a'], 3, 5))
    ['aaa', 'aaaa', 'aaaaa']
    """
    # one iterator per power; their heads grow with r, as the merge expects
    powers = _powers(xs, minr, maxr)
    return _imerge_bylen(powers)


def closure(xs):
    """Kleene closure (star) of xs: every concatenation of zero or more
    elements of xs, i.e. repeat(xs, 0, infinite).

    Writing closure directly would be easy, but repeat with a finite range
    is needed anyway, so closure simply delegates to it; almost_infinite
    stands in for an unbounded upper limit.

    >>> list(islice(closure(['a']), 0, 5))
    ['', 'a', 'aa', 'aaa', 'aaaa']
    """
    return repeat(xs, minr=0, maxr=almost_infinite)


def test():
    """Run the module's self-checks, then its doctests (verbosely).

    A failed self-check raises AssertionError; doctest failures are
    reported by doctest itself.
    """

    class FakeIterator:
        # Iterable only through the legacy __getitem__ protocol: it yields
        # 'x' for indexes below 100 and fails loudly beyond that, so it
        # catches code that iterates an operand it never needed.
        def __getitem__(self, index):
            if index < 100:
                return 'x'
            assert False, 'should not iterate that far'

    assert list(merge([], [])) == []
    assert list(merge([], ['1', '2', '3'])) == ['1', '2', '3']
    assert list(merge(['1', '2', '3'], [])) == ['1', '2', '3']
    assert list(merge(['1', '2', '3'], ['2'])) == ['1', '2', '3']
    assert list(merge(['1', '2', '3'], ['9'])) == ['1', '2', '3', '9']
    # merge (via mergeinf.imerge) rejects inputs whose heads are out of
    # order; merge_unsorted (via heapq.merge) accepts them.
    try:
        list(merge(['9'], ['1', '2', '3']))
    except ValueError:
        pass
    else:
        raise AssertionError('merge should fail here!')
    assert list(merge_unsorted(['9'], ['1', '2', '3'])) == ['1', '2', '3', '9']
    assert list(merge(list('134'), list('124789'))) == [
        '1', '2', '3', '4', '7', '8', '9']
    assert list(merge(['a', 'aab'], ['b', 'bb'])) == ['a', 'b', 'bb', 'aab']
    assert list(
                 prod(['','c','bb'], ['d','ff'])
               ) == [
                 'd', 'cd', 'ff', 'bbd', 'cff', 'bbff'
               ]
    assert list(
                 prod2((x for x in ['','c','bb']), (x for x in ['d','ff']))
               ) == [
                 'd', 'cd', 'ff', 'bbd', 'cff', 'bbff'
               ]
    assert list(prod(['', 'a'], ['b', 'ab'], [])) == []
    assert list(prod(['', 'a'], [], FakeIterator())) == []
    assert list(prod(['', 'a'], ['b', 'ab'])) == ['b', 'ab', 'aab']
    assert list(prod(['', 'a', 'b'], ['a', 'b'])) == [
        'a', 'b', 'aa', 'ab', 'ba', 'bb']
    assert list(repeat([], 3)) == []
    assert list(repeat([''], 3)) == ['']
    a04 = ['', 'a', 'aa', 'aaa', 'aaaa']
    assert list(repeat(['a'], 0, 4)) == a04
    assert list(repeat(['', 'a'], 0, 4)) == a04
    assert list(repeat(['', 'a'], 1, 4)) == a04
    assert list(repeat(['a'], 1, 4)) == a04[1:]
    assert list(repeat(['a'], 2, 4)) == a04[2:]
    assert list(repeat(['a'], 3, 4)) == a04[3:]
    assert list(repeat(['a'], 4, 4)) == a04[4:]
    assert list(repeat(['a'], 4)) == a04[4:]
    assert list(repeat(['a', 'b'], 0, 2)) == [
        '', 'a', 'b', 'aa', 'ab', 'ba', 'bb']
    assert sum(1 for _ in repeat(['a', 'b'], 0, 5)) == 63
    r = list(repeat(['a', 'b'], 2, 3))
    assert len(r) == 12
    assert r[:5] == ['aa', 'ab', 'ba', 'bb', 'aaa']
    assert r[-1] == 'bbb'

    g = prod(closure(merge('a', prod('b', 'c'))), 'd')
    g = islice(g, 10000, 10005)
    assert next(g)=='bcabcaaaabcabcaaaad'

    # repeat(['a', 'bbbbb'], 1, 1000) holds more than 2**1000 distinct
    # strings, so walking it to the end would never finish (and the doctests
    # below would never run).  Check a long prefix instead: it must be
    # strictly increasing by (length, value).
    zprev = ''
    for z in islice(repeat(['a', 'bbbbb'], 1, 1000), 2000):
        assert len(z)>=len(zprev), (zprev, z)
        assert len(z)>len(zprev) or len(z)==len(zprev) and z>zprev, (zprev, z)
        zprev = z

    import doctest
    doctest.testmod(
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
        verbose=1)


# Running this file as a script executes the self-checks and the doctests.
if __name__ == '__main__':
    test()
