#
# invRegexInf.py
#
# Copyright 2008, Paul McGuire
#
# 2010-02-12: Modified by Gabriel Genellina to be able to handle
# unbounded repetitions like +, * and {m,}
#
# pyparsing script to expand a regular expression into all possible matching strings
# Supports:
# - {n}, {,n} and {m,n} repetitions
# - unbounded repetitions: +, *, {m,}
# - ? optional elements
# - [] character ranges
# - () grouping
# - | alternation
# - \d \w \s character classes
#

# Names exported by a star-import of this module.
__all__ = ["count","invert"]

import sys
from itertools import islice
from enumre import repeat, merge, merge_unsorted, prod
from pyparsing import (Literal, oneOf, printables, ParserElement, Combine,
    SkipTo, operatorPrecedence, ParseFatalException, Word, nums, opAssoc,
    Suppress, ParseResults, srange, Optional)

class CharacterRangeEmitter(object):
    """Iterable over the distinct characters of a character set.

    Repeated characters in *chars* are dropped; each character keeps the
    position of its first occurrence.
    """

    def __init__(self, chars):
        ordered = []
        already = set()
        for ch in chars:
            if ch in already:
                continue
            already.add(ch)
            ordered.append(ch)
        self.charset = "".join(ordered)

    def __repr__(self):
        return "[%s]" % self.charset

    def __iter__(self):
        # one single-character string per member of the set
        return iter(self.charset)

class OptionalEmitter(object):
    """Iterable for an optional element (the "?" operator).

    Yields the empty string first, then every string of the wrapped
    expression.
    """

    def __init__(self, expr):
        self.expr = expr

    def __iter__(self):
        # the "element absent" case
        yield ""
        # the "element present" cases
        for match in self.expr:
            yield match

class DotEmitter(object):
    """Iterable over every character in pyparsing's ``printables``."""

    def __iter__(self):
        return (ch for ch in printables)

class GroupEmitter(object):
    """Iterable for a repeated sub-expression.

    Stores the sub-expression and the repetition bounds; iteration is
    delegated to ``repeat(exprs, minr, maxr)`` from the enumre module.
    """

    # NOTE(review): the meaning of the default maxr=-1 is decided by
    # enumre.repeat -- presumably "no upper bound"; confirm there.
    def __init__(self, exprs, minr=1, maxr=-1):
        self.exprs, self.minr, self.maxr = exprs, minr, maxr

    def __iter__(self):
        body = self.exprs
        lowest, highest = self.minr, self.maxr
        return repeat(body, lowest, highest)

class AlternativeEmitter(object):
    """Iterable for alternation (the "|" operator).

    Iteration is delegated to ``merge_unsorted`` from the enumre module,
    called with each alternative as a separate argument.
    """

    def __init__(self, exprs):
        self.exprs = exprs

    def __iter__(self):
        branches = self.exprs
        return merge_unsorted(*branches)

class SequenceEmitter(object):
    """Iterable for a sequence of consecutive sub-expressions.

    Iteration is delegated to ``prod`` from the enumre module, called
    with each element of the sequence as a separate argument.
    """

    def __init__(self, exprs):
        self.exprs = exprs

    def __iter__(self):
        parts = self.exprs
        return prod(*parts)

class LiteralEmitter(object):
    """Iterable that produces exactly one string: the stored literal."""

    def __init__(self, lit):
        self.lit = lit

    def __repr__(self):
        return "Lit:" + self.lit

    def __iter__(self):
        # a one-element iteration holding the literal text
        return iter((self.lit,))

def handleRange(toks):
    """Parse action for a bracketed range such as ``[A-C]``.

    The first token is the bracket expression text; pyparsing's srange()
    expands it into its characters, which are wrapped in a
    CharacterRangeEmitter.
    """
    bracket_text = toks[0]
    members = srange(bracket_text)
    return CharacterRangeEmitter(members)

def handleRepetition(toks):
    """Parse action for a postfix repetition operator.

    *toks* holds one group.  Its first element is the operand (an
    emitter) and its second element is the first token of the operator:
    "*", "+", "?" or "{".  For the brace forms the numbers are taken from
    the named results "count" ({n}) or "minCount"/"maxCount" ({m,n}).

    Returns an OptionalEmitter for "?" and a GroupEmitter otherwise.
    In the comma form either bound may be left out -- including both, as
    in "{,}" -- a missing lower bound means 0 and a missing upper bound
    means unbounded.
    """
    toks = toks[0]
    operand, op = toks[0], toks[1]

    # Stand-in for "no upper bound".  sys.maxint exists only on Python 2;
    # keep using it there and fall back to sys.maxsize elsewhere.
    unbounded = getattr(sys, "maxint", sys.maxsize)

    if op == "*":
        return GroupEmitter(operand, 0, unbounded)
    if op == "+":
        return GroupEmitter(operand, 1, unbounded)
    if op == "?":
        return OptionalEmitter(operand)

    if "count" in toks:
        # {n}: exactly n repetitions
        exact = int(toks["count"])
        return GroupEmitter(operand, exact, exact)

    # Comma form.  The grammar makes both numbers optional, so this branch
    # must not require either name to be present: previously "{,}" matched
    # no branch at all and the tokens were left unconverted.
    mincount = int(toks.minCount) if toks.minCount else 0
    maxcount = int(toks.maxCount) if toks.maxCount else unbounded
    return GroupEmitter(operand, mincount, maxcount)

def handleLiteral(toks):
    """Parse action for literal characters.

    Each token is either a plain character or a two-character backslash
    escape.  An escaped "t" becomes a tab; any other escape stands for
    the character after the backslash.  The pieces are concatenated and
    wrapped in a LiteralEmitter.
    """
    pieces = []
    for tok in toks:
        if tok[0] != "\\":
            pieces.append(tok)
        elif tok[1] == "t":
            pieces.append('\t')
        else:
            pieces.append(tok[1])
    return LiteralEmitter("".join(pieces))

def handleMacro(toks):
    """Parse action for the backslash character classes.

    The first token is the two-character macro; its second character
    selects the emitter:

      d -> the ten decimal digits
      w -> letters, digits and underscore
      s -> a single space

    Any other letter raises ParseFatalException.
    """
    letter = toks[0][1]
    if letter == "d":
        return CharacterRangeEmitter("0123456789")
    if letter == "w":
        return CharacterRangeEmitter(srange("[A-Za-z0-9_]"))
    if letter == "s":
        return LiteralEmitter(" ")
    message = "unsupported macro character (" + letter + ")"
    raise ParseFatalException("", 0, message)

def handleSequence(toks):
    """Parse action for juxtaposed terms: wrap the group in a SequenceEmitter."""
    members = toks[0]
    return SequenceEmitter(members)

def handleDot():
    """Parse action for ".": an emitter over pyparsing's ``printables``."""
    any_printable = CharacterRangeEmitter(printables)
    return any_printable

def handleAlternative(toks):
    """Parse action for "|": wrap the group of alternatives in an AlternativeEmitter."""
    choices = toks[0]
    return AlternativeEmitter(choices)


# Cached grammar; built on the first call to parser() and reused afterwards.
_parser = None
def parser():
    """Return the pyparsing grammar for the supported regex subset.

    The grammar is constructed once and stored in the module-level
    ``_parser``; later calls return the same object.  Its parse actions
    (the handle* functions) replace the matched tokens with emitter
    objects.
    """
    global _parser
    if _parser is None:
        # Set pyparsing's default whitespace characters to none before any
        # element is created; blanks are accepted as literal characters
        # below (see reLiteralChar).
        ParserElement.setDefaultWhitespaceChars("")
        # lparen and rparen are created here but not referenced again in
        # this function.
        lbrack,rbrack,lbrace,rbrace,lparen,rparen = map(Literal,"[]{}()")

        # backslash followed by d, w or s
        reMacro = Combine("\\" + oneOf(list("dws")))
        # backslash followed by any printable character, except where the
        # text is a macro
        escapedChar = ~reMacro + Combine("\\" + oneOf(list(printables)))
        # every printable that is not a regex metacharacter, plus space and tab
        reLiteralChar = "".join(c for c in printables if c not in r"\[]{}().*?+|") + " \t"

        # "[" ... "]" kept as one token; escaped characters inside are
        # passed to SkipTo as text to ignore while looking for "]"
        reRange = Combine(lbrack + SkipTo(rbrack,ignore=escapedChar) + rbrack)
        reLiteral = ( escapedChar | oneOf(list(reLiteralChar)) )
        reDot = Literal(".")
        # {n}  |  {m,n} with either number optional  |  one of * + ?
        repetition = (
            ( lbrace + Word(nums).setResultsName("count") + rbrace ) |
            ( lbrace +
                Optional(Word(nums).setResultsName("minCount")) +
                "," +
                Optional(Word(nums).setResultsName("maxCount")) +
                rbrace ) |
            oneOf(list("*+?"))
            )

        reRange.setParseAction(handleRange)
        reLiteral.setParseAction(handleLiteral)
        reMacro.setParseAction(handleMacro)
        reDot.setParseAction(handleDot)

        reTerm = ( reLiteral | reRange | reMacro | reDot )
        # Operator levels, tightest-binding first: postfix repetition,
        # implicit sequencing (no operator token, hence None), then "|".
        reExpr = operatorPrecedence( reTerm,
            [
            (repetition, 1, opAssoc.LEFT, handleRepetition),
            (None, 2, opAssoc.LEFT, handleSequence),
            (Suppress('|'), 2, opAssoc.LEFT, handleAlternative),
            ]
            )
        _parser = reExpr

    return _parser

def count(gen):
    """Return how many items the iterable *gen* produces.

    The iterable is consumed completely, so an infinite generator makes
    this call run forever.
    """
    return sum(1 for _ in gen)

def invert(regex):
    """Return an iterable over the strings matched by *regex*.

    The expression is parsed with the grammar from parser(); the first
    element of the parse result is returned and is meant to be iterated.

    Warning: with an unbounded repetition (*, + or {m,}) the iteration
    never ends.
    Strings come in length order (shorter strings first).

        for s in invert("[A-Z]{3}[0-9]{3}"):
            print(s)
    """
    grammar = parser()
    parsed = grammar.parseString(regex)
    return parsed[0]

def main():
    """Self-test: expand each sample regex and print some of its strings.

    For every sample the regex is printed, followed by the number of
    matching strings (skipped when the text looks like it contains an
    unbounded repetition) and the first 20 generated strings.  Up to 5000
    generated strings are checked against the regex with re.match.  A
    ParseFatalException raised for a sample is reported and the run moves
    on to the next one.
    """
    import re
    tests = r"""
    [A-EA]
    [A-D]*
    [A-D]{3}
    X[A-C]{3}Y
    X[A-C]{3}\(
    X\d
    foobar\d\d
    foobar{2}
    foobar{2,9}
    foobar{,9}
    fooba[rz]{2}
    (foobar){2}
    ([01]\d)|(2[0-5])
    ([01]\d\d)|(2[0-4]\d)|(25[0-5])
    [A-C]{1,2}
    [A-C]{0,3}
    [A-C]\s[A-C]\s[A-C]
    [A-C]\s?[A-C][A-C]
    [A-C]\s([A-C][A-C])
    [A-C]\s([A-C][A-C])?
    [A-C]{2}\d{2}
    @|TH[12]
    @(@|TH[12])?
    @(@|TH[12]|AL[12]|SP[123]|TB(1[0-9]?|20?|[3-9]))?
    @(@|TH[12]|AL[12]|SP[123]|TB(1[0-9]?|20?|[3-9])|OH(1[0-9]?|2[0-9]?|30?|[4-9]))?
    (([ECMP]|HA|AK)[SD]|HS)T
    [A-CV]{2}
    A[cglmrstu]|B[aehikr]?|C[adeflmorsu]?|D[bsy]|E[rsu]|F[emr]?|G[ade]|H[efgos]?|I[nr]?|Kr?|L[airu]|M[dgnot]|N[abdeiop]?|Os?|P[abdmortu]?|R[abefghnu]|S[bcegimnr]?|T[abcehilm]|Uu[bhopqst]|U|V|W|Xe|Yb?|Z[nr]
    (a|b)|(x|y)
    (a|b) (x|y)
    (a|bbb)+(iii|jj|k)*(xxx|y)*z
    [ab]{1,3}[cd]{,3}[ef]{2,}
    ([xy]{1,3}z|([m-p]+)){0,3}
    """.split('\n')

    # Every print below is a call with exactly one argument, which behaves
    # the same whether print is a statement or a function.  A blank line is
    # therefore written as print("") rather than a bare print.
    for t in tests:
        t = t.strip()
        if not t: continue
        print('-'*50)
        print(t)
        try:
            if '*' in t or '+' in t or ',}' in t:
                print("seems to contain unbounded repetitions, not counted!")
            else:
                print(count(invert(t)))
            # generate (at most) 5000 strings; only the first 20 are printed
            for i,s in enumerate(islice(invert(t), 0, 5000)):
                if i < 20: print(s)
                # re.match anchors at the start only, so this checks that
                # each generated string begins with a match
                assert re.match(t, s)
        except ParseFatalException as pfe:
            print(pfe.msg)
            print("")
            continue
        print("")


# Run the self-test when the file is executed as a script.
if __name__ == "__main__":
    main()