#!/usr/bin/python

import copy
import pprint
import traceback
import sys
import operator

class Stack(object):
    """Minimal LIFO stack backed by a Python list (the top is the list end)."""

    def __init__(self):
        self.data = []

    def empty(self):
        """Return True when the stack holds no values."""
        return not self.data

    def push(self, ob):
        """Place ``ob`` on top of the stack."""
        self.data += [ob]

    def pop(self):
        """Remove and return the top value (IndexError when empty)."""
        return self.data.pop()

    def top(self):
        """Return the top value without removing it (IndexError when empty)."""
        return self.data[-1]

class Interpreter(object):
    """Interpreter-wide state reachable from running frames.

    ``operator`` holds binary-operator implementations and ``definition``
    holds the handlers for suite declarations.
    """

    def __init__(self):
        self.operator = Namespace()
        self.definition = Namespace()
        self.specialnames = set(("operator_call",))

    def get_special(self, frame, name):
        """Resolve special ``name`` relative to ``frame``.

        Supported names: operator, definition, outer (parent of the frame's
        scope) and context (the frame's scope). Anything else raises.
        """
        if name in ("outer", "context"):
            scope = frame.context
            return scope.parent if name == "outer" else scope
        if name == "definition":
            return self.definition
        if name == "operator":
            return self.operator
        raise Exception("No such special:" + name)

class Exits(object):
    """Destinations for a frame's return, yield and raise transfers.

    The ``*_to`` methods return a modified copy and never change the receiver.
    """

    def __init__(self, returne, yielde, raisee):
        self.freturn = returne
        self.fyield = yielde
        self.fraise = raisee

    def _clone(self):
        """Shallow copy with the same three destinations."""
        return Exits(self.freturn, self.fyield, self.fraise)

    def returns_to(self, frame):
        updated = self._clone()
        updated.freturn = frame
        return updated

    def raises_to(self, frame):
        updated = self._clone()
        updated.fraise = frame
        return updated

    def yields_to(self, frame):
        updated = self._clone()
        updated.fyield = frame
        return updated

# Sentinel standing in for the interpreted language's "no value": define_def
# appends Return(Const(ApNone)) as the implicit end of a function body, and
# define_gen appends a loop that keeps yielding it.
ApNone = object()

#########################################
# Op Codes
#########################################

# basic stack operations
def OP_pop(frame):
    """Discard the value on top of the frame's value stack."""
    stack = frame.valuestack
    stack.pop()

def OP_rot(frame):
    """Swap the two topmost values of the frame's value stack."""
    stack = frame.valuestack
    first, second = stack.pop(), stack.pop()
    for value in (first, second):
        stack.push(value)

def OP_dup(frame):
    """Push a second reference to the value currently on top of the stack."""
    stack = frame.valuestack
    stack.push(stack.top())

# variable access

def OP_loadcontext(frame):
    """Push the frame's current scope onto its value stack."""
    scope = frame.context
    frame.valuestack.push(scope)

def OP_set(name):
    """Build an op that rebinds ``name``.

    The op pops the scope (top of stack), then the value, and calls
    ``scope.set(name, value)``.
    """
    def set_op(frame):
        stack = frame.valuestack
        target = stack.pop()
        value = stack.pop()
        target.set(name, value)
    return set_op

def OP_define(name):
    """Build an op that creates binding ``name``.

    The op pops the value (top of stack), then the scope, and calls
    ``scope.define(name, value)``.
    """
    def define_op(frame):
        stack = frame.valuestack
        value = stack.pop()
        target = stack.pop()
        target.define(name, value)
    return define_op

def OP_loadname(name):
    """Build an op that replaces the scope on top of the stack with
    ``scope.lookup(name)``."""
    def loadname_op(frame):
        stack = frame.valuestack
        scope = stack.pop()
        found = scope.lookup(name)
        stack.push(found)
    return loadname_op

def OP_loadconst(val):
    """Build an op that pushes the constant ``val``."""
    def loadconst_op(frame):
        stack = frame.valuestack
        stack.push(val)
    return loadconst_op

def OP_loadspecial(name):
    """Build an op that pushes ``interpreter.get_special(frame, name)``."""
    def loadspecial_op(frame):
        special = frame.interpreter.get_special(frame, name)
        frame.valuestack.push(special)
    return loadspecial_op

# try block

def OP_try_except(frame):
    """Start a try/except.

    Pops the handler code (top of stack), then the body code. Returns a new
    frame running the body in the current scope. The body returns to
    ``frame`` and raises into a handler frame, which itself returns to
    ``frame``.
    """
    stack = frame.valuestack
    handler_code = stack.pop()
    body_code = stack.pop()
    handler = Frame(handler_code, frame.context, frame.exits.returns_to(frame))
    body_exits = frame.exits.returns_to(frame).raises_to(handler)
    return Frame(body_code, frame.context, body_exits)

def OP_try_finally(frame):
    """Start a try/finally.

    Pops the finally code (top of stack), then the body code. Returns a new
    frame running the body in the current scope. Return, yield and raise from
    the body all transfer into the finally frame, which returns to ``frame``.
    """
    finallycode = frame.valuestack.pop()
    code = frame.valuestack.pop()
    handleframe = Frame(finallycode, frame.context, frame.exits.returns_to(frame))
    # Previously referenced the undefined name ``handle_frame`` (NameError).
    newframe = Frame(code, frame.context,
                     Exits(handleframe, handleframe, handleframe))
    return newframe

# control flow transfer

def OP_raise(frame):
    """Pop the raised value and hand it to the frame's raise target.

    Returns the raise target, or a falsy value when there is none. In that
    case the value is simply dropped and Frame.step lets this frame continue.
    """
    exc = frame.valuestack.pop()
    target = frame.exits.fraise
    if target:
        target.valuestack.push(exc)
    return target

def OP_return(frame):
    """Move the top value to the return target and transfer control there.

    When there is no return target, nothing is popped and the falsy target is
    returned.
    """
    target = frame.exits.freturn
    if target:
        result = frame.valuestack.pop()
        target.valuestack.push(result)
    return target

def OP_yield(frame):
    """Move the top value to the yield target and transfer control there.

    When there is no yield target, nothing is popped and the falsy target is
    returned.
    """
    target = frame.exits.fyield
    if target:
        result = frame.valuestack.pop()
        target.valuestack.push(result)
    return target

def OP_call(frame):
    """Call a function from the stack.

    Expected stack layout (top last): func, arg1 .. argN, N. Returns whatever
    ``func.call(frame, [arg1 .. argN])`` returns (the next frame to run).
    """
    stack = frame.valuestack
    count = stack.pop()
    # Arguments come off the stack last-first; restore source order.
    args = [stack.pop() for _ in range(count)]
    args.reverse()
    func = stack.pop()
    return func.call(frame, args)

def OP_jump(amount):
    """Build an op that moves the frame's program counter by ``amount``.

    Frame.step adds 1 after every op, so the jump lands ``amount + 1`` ops
    past the jump itself.
    """
    def jump_op(frame):
        frame.pos = frame.pos + amount
    return jump_op

def OP_jump_if_false(amount):
    """Build an op that jumps by ``amount`` when the top value is falsy.

    The tested value is left on the stack; the code on both paths must pop it.
    """
    def jump_if_op(frame):
        if not frame.valuestack.top():
            frame.pos += amount
    return jump_if_op

#########################################
# internal functions
#########################################

def call_object(obj, args, context):
    """Call ``obj`` with ``args`` and step the VM until the call finishes.

    A throwaway caller frame (empty code, scope ``context``) is used whose
    exits lead to three sentinel DummyFrames; frames are stepped until one of
    those sentinels is reached.

    Returns (value, "RETURN") or (value, "YIELD"); value is None when nothing
    was produced. A value that reaches the raise sentinel is raised as a
    Python exception (wrapped in Exception when it is not one already).
    """
    normal = DummyFrame()
    yielded = DummyFrame()
    raised = DummyFrame()
    caller = Frame([], context, Exits(normal, yielded, raised))
    runningframe = obj.call(caller, args)
    while not isinstance(runningframe, DummyFrame):
        runningframe = runningframe.step()

    if runningframe is raised and not raised.valuestack.empty():
        exc = raised.valuestack.pop()
        # The VM raises arbitrary values (BuiltinFunction pushes exception
        # messages); Python can only raise exception instances.
        if isinstance(exc, BaseException):
            raise exc
        raise Exception(exc)
    if runningframe is yielded and not yielded.valuestack.empty():
        return yielded.valuestack.pop(), "YIELD"
    if runningframe is normal:
        if not normal.valuestack.empty():
            return normal.valuestack.pop(), "RETURN"
        # Function.call and Frame.call route returns to ``caller``, and
        # BuiltinFunction.call pushes its result onto it. The result is
        # therefore on the caller's stack when the empty caller falls
        # through to ``normal``.
        if not caller.valuestack.empty():
            return caller.valuestack.pop(), "RETURN"
    return None, "RETURN"

#########################################
# object system
#########################################

class Object(object):
    """Base value of the interpreted object model.

    Each object has a ``type`` (a Class, or None) and its own attribute
    ``dict``. Lookups and stores that miss the own dict are delegated to the
    type. The type's entry is then used through the property protocol
    ``get_attr(inst, name)`` / ``set_attr(inst, name, val)``.
    """

    def __init__(self, type):
        self.type = type
        self.dict = {}

    def lookup_attr(self, name):
        """Return own attribute ``name``, or None when absent."""
        return self.dict.get(name, None)

    def lookup(self, name):
        """Resolve attribute ``name`` (own dict first, then via the type).

        Raises Exception when the attribute is absent and there is no type.
        """
        var = self.lookup_attr(name)
        # Test against None, not truthiness, so attributes holding 0, "",
        # False or empty containers are still found.
        if var is not None:
            return var
        if self.type:
            return self.type.lookup(name).get_attr(self, name)
        # Only Class objects carry ``name``; fall back to repr so formatting
        # the error cannot itself raise AttributeError.
        raise Exception("Attribute '%s' not found on %s"
                        % (name, getattr(self, "name", self)))

    def set_attr(self, *args):
        """Attribute-store hook, dispatched on arity.

        ``set_attr(name, val)``: rebind ``name`` if it is already in this
        object's own dict; return True if so, otherwise False.

        ``set_attr(inst, name, val)``: property-setter protocol used by
        ``set`` on the type's entry. Plain objects are not properties, so
        this form raises Exception("Not a property").

        One method serves both forms because Python keeps only the last
        ``def`` of a name. Two separate definitions made the property form
        replace the own-dict form, and ``set`` failed with TypeError.
        """
        if len(args) == 3:
            raise Exception("Not a property")
        name, val = args
        if name in self.dict:
            self.dict[name] = val
            return True
        return False

    def set(self, name, val):
        """Rebind existing attribute ``name`` (own dict first, then via type).

        Raises Exception when the attribute is absent and there is no type.
        """
        if self.set_attr(name, val):
            return
        if self.type:
            self.type.lookup(name).set_attr(self, name, val)
        else:
            raise Exception("Attribute '%s' not found on %s"
                            % (name, getattr(self, "name", self)))

    def define(self, name, val):
        """Create or overwrite own attribute ``name``."""
        self.dict[name] = val

    def get_attr(self, inst, name):
        """Property-getter protocol; plain objects are not properties."""
        raise Exception("Not a property")

    def call(self, frame, args):
        raise Exception("Instance not callable")

    def __repr__(self):
        return "%s()" % (self.type,)

class Class(Object):
    """A class of the interpreted object model.

    ``dict`` holds class attributes, ``bases`` the direct base classes and
    ``mro`` the linearised ancestors (the class itself is not included).
    """

    def __init__(self, name, metatype=None, bases=None, dict=None):
        Object.__init__(self, metatype)
        self.name = name
        self.dict = dict or {}
        self.bases = bases or []
        self.mro = self.calc_mro()

    def dfs_bases(self, list):
        """Append this class, then each base's ancestry depth-first, to ``list``."""
        list.append(self)
        for parent in self.bases:
            parent.dfs_bases(list)

    def calc_mro(self):
        """Linearise the ancestors.

        Walk the bases depth-first and keep only the LAST occurrence of each
        class, so a shared base comes after every class deriving from it.
        """
        walk = []
        for parent in self.bases:
            parent.dfs_bases(walk)
        kept = []
        for klass in walk[::-1]:
            if klass not in kept:
                kept.append(klass)
        kept.reverse()
        return kept

    def lookup_attr(self, name):
        """Find ``name`` on this class or its ancestors; None if absent."""
        for klass in [self] + self.mro:
            value = klass.dict.get(name, None)
            if value is not None:
                return value
        return None

    def set_attr(self, name, val):
        """Rebind ``name`` on the first class in [self] + mro that defines it.

        Returns True on success, False if no class defines it.
        """
        for klass in [self] + self.mro:
            if name in klass.dict:
                klass.dict[name] = val
                return True
        return False

    def call(self, frame, args):
        """Instantiate: push a fresh instance onto ``frame`` (args unused)."""
        frame.valuestack.push(Object(self))
        return frame

    def __repr__(self):
        return self.name


# Root of the object model. Its type is None, so an attribute lookup that
# reaches it without a hit raises instead of delegating further.
Metaclass = Class("metaclass", None, None, {})

# Common base class listed as the base of the builtin classes below.
Baseobject = Class("object", Metaclass, None, {})

SelfPropertyClass = Class("self_property", Metaclass, [Baseobject])


class SelfProperty(Object):
    """Property whose value is the instance it is read through."""

    def __init__(self):
        Object.__init__(self, SelfPropertyClass)

    def get_attr(self, inst, name):
        # ``name`` is irrelevant: this property always yields the instance.
        result = inst
        return result


Metaclass.dict["operator_call"] = SelfProperty()

BuiltinFunctionClass = Class("builtin_function", Metaclass, [Baseobject])


class BuiltinFunction(Object):
    """Interpreted-language callable backed by a host Python function.

    ``func`` is invoked as ``func(frame, newcontext, *args)``. ``newcontext``
    is a fresh Context whose parent is the caller's scope.
    """

    def __init__(self, args, func):
        Object.__init__(self, BuiltinFunctionClass)
        self.args = args  # parameter names; not consulted by call()
        self.func = func

    def call(self, frame, args):
        """Run the host function and return the frame to execute next.

        On success the result is pushed onto ``frame`` and ``frame`` resumes.
        If the host function raises, the traceback is printed. The error's
        first argument (or the exception itself when it has none) is then
        pushed to the caller's raise target, and that frame runs next. When
        the caller has no raise target, the Python exception propagates.
        """
        newcontext = Context(frame.context)
        try:
            retval = self.func(frame, newcontext, *args)
        except Exception as e:
            handler = frame.exits.fraise
            if handler is None:
                raise
            traceback.print_exc()
            # Exceptions such as a bare AssertionError carry no args; indexing
            # args[0] unconditionally raised IndexError here and masked the
            # original error.
            handler.valuestack.push(e.args[0] if e.args else e)
            return handler
        frame.valuestack.push(retval)
        return frame

    def get_attr(self, inst, name):
        """Read through a type, the builtin becomes a method bound to ``inst``."""
        return BoundFunction(self, inst)


BuiltinFunctionClass.dict["operator_call"] = SelfProperty()
		
# Previously named "builtin_function" (copy-paste), which made bound methods
# indistinguishable from builtins in reprs and ``type()`` output.
BoundFunctionClass = Class("bound_function", Metaclass, [Baseobject])


class BoundFunction(Object):
    """A function paired with the instance it was read through (a method)."""

    def __init__(self, func, inst):
        Object.__init__(self, BoundFunctionClass)
        self.func = func
        self.inst = inst

    def call(self, frame, args):
        """Call the wrapped function with the bound instance prepended.

        ``args`` may be any sequence (list or tuple).
        """
        return self.func.call(frame, [self.inst] + list(args))


BoundFunctionClass.dict["operator_call"] = SelfProperty()

FunctionClass = Class("function", Metaclass, [Baseobject])


class Function(Object):
    """User-defined function: parameter names, compiled ops and closure scope."""

    def __init__(self, args, code, context):
        Object.__init__(self, FunctionClass)
        self.args = args
        self.code = code
        self.context = context  # defining scope; each call gets a child of it

    def call(self, frame, args):
        """Bind ``args`` in a new child scope and return the frame to run.

        The new frame returns to ``frame`` and keeps the caller's yield and
        raise targets.
        """
        assert len(args) == len(self.args)
        scope = Context(self.context)
        for pair in zip(self.args, args):
            scope.define(*pair)
        callee_exits = frame.exits.returns_to(frame)
        return Frame(self.code, scope, callee_exits)

    def get_attr(self, inst, name):
        """Read through a type, the function becomes a method bound to ``inst``."""
        return BoundFunction(self, inst)


FunctionClass.dict["operator_call"] = SelfProperty()
        
GeneratorClass = Class("generator", Metaclass, [Baseobject])


class Generator(Object):
    """Resumable function object.

    The first call binds arguments by name and starts the body. Each later
    call pushes its arguments (as a tuple) onto the suspended frame's stack
    and resumes it.
    """

    def __init__(self, args, code, context):
        Object.__init__(self, GeneratorClass)
        self.args = args
        self.newcontext = Context(context)
        # No caller exists yet; call() installs real exits on every call.
        # (This used to read an undefined name ``frame`` and raise NameError.)
        self.newframe = Frame(code, self.newcontext, Exits(None, None, None))
        self.firstcall = True

    def call(self, frame, args):
        """Start or resume the generator; it returns/yields back to ``frame``."""
        if self.firstcall:
            self.firstcall = False
            for name, arg in zip(self.args, args):
                self.newcontext.define(name, arg)
        else:
            assert len(args) == len(self.args)
            self.newframe.valuestack.push(tuple(args))
        # Return and yield go to the current caller; raises go to the
        # caller's raise target.
        self.newframe.exits = frame.exits.returns_to(frame).yields_to(frame)
        return self.newframe

NamespaceClass = Class("namespace", Metaclass, [Baseobject])


class Namespace(Object):
    """Flat name table without a parent (used for Interpreter.operator/definition)."""

    def __init__(self):
        Object.__init__(self, NamespaceClass)

    def lookup(self, name):
        """Return the value bound to ``name``; raise if absent or bound to None."""
        found = self.dict.get(name, None)
        if found is None:
            raise Exception("No such variable: %s" % (name,))
        return found

    def set(self, name, val):
        """Rebind an existing ``name``; raise if it is not defined."""
        if name not in self.dict:
            raise Exception("No such variable: %s" % (name,))
        self.dict[name] = val

    def define(self, name, val):
        """Create or overwrite binding ``name``."""
        self.dict[name] = val

    def __repr__(self):
        return repr(self.dict)

ContextClass = Class("context", Metaclass, [Baseobject])


class Context(Object):
    """Lexical scope: its own bindings plus an optional ``parent`` scope."""

    def __init__(self, parent):
        Object.__init__(self, ContextClass)
        self.parent = parent

    def _owner(self, name):
        """Return the nearest scope (self or an ancestor) binding ``name``, or None.

        Uses key membership rather than ``value is not None``. A variable
        bound to None (e.g. the result of the builtin ``print``) is a real
        binding; it must not be skipped or reported as undefined.
        """
        context = self
        while context is not None:
            if name in context.dict:
                return context
            context = context.parent
        return None

    def lookup(self, name):
        """Return the value of the nearest binding of ``name``; raise if none."""
        owner = self._owner(name)
        if owner is None:
            raise Exception("No such variable: %s" % (name,))
        return owner.dict[name]

    def set(self, name, val):
        """Rebind the nearest existing binding of ``name``; raise if none."""
        owner = self._owner(name)
        if owner is None:
            raise Exception("No such variable: %s" % (name,))
        owner.dict[name] = val

    def define(self, name, val):
        """Create or overwrite ``name`` in this scope (never in a parent)."""
        self.dict[name] = val

class Stacking(object):
    """Mixin that gives frame-like objects their own value stack."""

    def __init__(self):
        stack = Stack()
        self.valuestack = stack

FrameClass = Class("frame", Metaclass, [Baseobject])
class Frame(Stacking, Object):
    def __init__(self, code, context, exits, debug=False):
        Stacking.__init__(self)
        Object.__init__(self, FrameClass)
        self.context = context
        self.code = code
        self.exits = exits
        self.pos = 0
        self.valuestack = Stack()
        self.debug = debug
        global interpreter
        self.interpreter = interpreter

    def jump(amount):
        self.pos += amount

    def step(self):
        if self.pos >= len(self.code):
            return self.exits.freturn

        command = self.code[self.pos]
        frame = command(self)
        if self.debug:
            print command, self.context.dict
        self.pos += 1
        if frame is not None:
            return frame
        else:
            return self

    def details(self):
        return self.pos, len(self.code), self.exits

    def call(self, frame, args):
        assert len(args) == 0
        self.exits = frame.exits.returns_to(frame)
        return self

class DummyFrame(Stacking):
    """Sentinel exit frame: only a value stack. call_object stops stepping here."""

#########################################
# ast objects
#########################################

NodeClass = Class("node", Metaclass, [Baseobject])


def node_render(frame, context, self):
    """Builtin ``render`` method of AST nodes: return the node's op list."""
    ops = []
    self.render(ops)
    return ops


def node_type(frame, context, self):
    """Builtin ``type`` method of AST nodes: the node's Python class name."""
    cls = self.__class__
    return cls.__name__


NodeClass.dict["render"] = BuiltinFunction(("self",), node_render)
NodeClass.dict["type"] = BuiltinFunction(("self",), node_type)


class Node(Object):
    """Base class of AST nodes; every node's object-model type is NodeClass."""

    def __init__(self):
        Object.__init__(self, NodeClass)

    def render(self, code):
        """Append this node's ops to ``code``; the base node emits nothing."""
        return None
        

class Suite(Node):
    """Sequence of statements rendered back to back."""

    def __init__(self, children):
        Node.__init__(self)
        self.children = children

    def append(self, node):
        """Add ``node`` as the last statement."""
        self.children.append(node)

    def render(self, code):
        for stmt in self.children:
            stmt.render(code)

class SuiteDeclaration(Node):
    """Block declaration such as ``def f(x): ...``.

    Compiles to a call of ``definition.<type>(ast, suite)``, whose result is
    discarded.
    """

    def __init__(self, type, ast, suite):
        Node.__init__(self)
        # NOTE(review): this overwrites Object.type (NodeClass, set by
        # Node.__init__) with the declaration keyword, so object-model
        # attribute lookups on this node will not reach NodeClass. Confirm.
        self.type = type
        self.ast = ast
        self.suite = suite

    def render(self, code):
        code.append(OP_loadspecial("definition"))
        code.append(OP_loadname(self.type))
        self.ast.render(code)
        code.extend([OP_loadconst(self.suite), OP_loadconst(2), OP_call, OP_pop])

class Return(Node):
    """``return expr``: evaluate expr, then leave through the return exit."""

    def __init__(self, expr):
        Node.__init__(self)
        self.expr = expr

    def render(self, code):
        self.expr.render(code)
        code += [OP_return]

class Yield(Node):
    """``yield expr``: evaluate expr, then transfer through the yield exit."""

    def __init__(self, expr):
        Node.__init__(self)
        self.expr = expr

    def render(self, code):
        self.expr.render(code)
        code += [OP_yield]

class Raise(Node):
    """``raise expr``: evaluate expr, then transfer through the raise exit."""

    def __init__(self, expr):
        Node.__init__(self)
        self.expr = expr

    def render(self, code):
        self.expr.render(code)
        code += [OP_raise]

class Define(Node):
    """Create binding ``name = child`` in the current scope (no value left)."""

    def __init__(self, name, child):
        Node.__init__(self)
        self.name = name
        self.child = child

    def render(self, code):
        code.append(OP_loadcontext)
        self.child.render(code)
        code.append(OP_define(self.name))

class Set(Node):
    """Rebind ``name`` to ``child``; the assigned value stays on the stack."""

    def __init__(self, name, child):
        Node.__init__(self)
        self.name = name
        self.child = child

    def render(self, code):
        self.child.render(code)
        code.append(OP_dup)
        code.append(OP_loadcontext)
        code.append(OP_set(self.name))

class Call(Node):
    """Call expression ``child(args...)``."""

    def __init__(self, child, args):
        Node.__init__(self)
        self.child = child
        self.args = args

    def render(self, code):
        self.child.render(code)
        for arg in self.args:
            arg.render(code)
        code.append(OP_loadconst(len(self.args)))
        code.append(OP_call)

class BinaryOperator(Node):
    """Infix ``lhs op rhs``, compiled to a call of ``operator.<name>(lhs, rhs)``."""

    # Source symbol -> entry name in the interpreter's operator namespace.
    ops = {"+": "add",
           "-": "subtract",
           "*": "multiply",
           "/": "divide",
           "<": "less_than",
           ">": "greater_than",
           "==": "equal_to",
           "!=": "not_equal_to"}

    def __init__(self, op, lhs, rhs):
        Node.__init__(self)
        self.op = self.ops[op]  # KeyError for unsupported symbols
        self.lhs = lhs
        self.rhs = rhs

    def render(self, code):
        code.append(OP_loadspecial("operator"))
        code.append(OP_loadname(self.op))
        for operand in (self.lhs, self.rhs):
            operand.render(code)
        code.append(OP_loadconst(2))
        code.append(OP_call)

class _NamedOperatorNode(Node):
    """Binary node compiled to a call of ``operator.<opname>(lhs, rhs)``.

    Shared implementation of As and In, whose bodies were line-for-line
    duplicates (and mixed tabs with spaces). Subclasses set ``opname``.
    """

    opname = None

    def __init__(self, lhs, rhs):
        Node.__init__(self)
        self.lhs = lhs
        self.rhs = rhs

    def render(self, code):
        code.extend([OP_loadspecial("operator"),
                     OP_loadname(self.opname)])
        self.lhs.render(code)
        self.rhs.render(code)
        code.extend([OP_loadconst(2),
                     OP_call])


class As(_NamedOperatorNode):
    """``lhs as rhs``; calls the ``as`` operator entry."""
    opname = "as"


class In(_NamedOperatorNode):
    """``lhs in rhs``; calls the ``in`` operator entry."""
    opname = "in"


class ExprStmt(Node):
    """Expression evaluated for its side effects; the result is discarded."""

    def __init__(self, expr):
        Node.__init__(self)
        self.expr = expr

    def render(self, code):
        self.expr.render(code)
        code += [OP_pop]

class Const(Node):
    """Literal value pushed as-is."""

    def __init__(self, val):
        Node.__init__(self)
        self.val = val

    def render(self, code):
        code.append(OP_loadconst(self.val))

class VarLookup(Node):
    """Read variable ``name`` from the current scope chain."""

    def __init__(self, name):
        Node.__init__(self)
        self.name = name

    def render(self, code):
        code.append(OP_loadcontext)
        code.append(OP_loadname(self.name))

class SpecialLookup(Node):
    """Push an interpreter special (see Interpreter.get_special)."""

    def __init__(self, name):
        Node.__init__(self)
        self.name = name

    def render(self, code):
        code.append(OP_loadspecial(self.name))

class AttrLookup(Node):
    """``expr.name``: evaluate expr, then look ``name`` up on the result."""

    def __init__(self, name, expr):
        Node.__init__(self)
        self.expr = expr
        self.name = name

    def render(self, code):
        self.expr.render(code)
        code += [OP_loadname(self.name)]

class DefineAttr(Node):
    """Create attribute ``objexpr.name = valexpr`` (no value left on the stack)."""

    def __init__(self, name, objexpr, valexpr):
        Node.__init__(self)
        self.name = name
        self.objexpr = objexpr
        self.valexpr = valexpr

    def render(self, code):
        for part in (self.objexpr, self.valexpr):
            part.render(code)
        code += [OP_define(self.name)]

class SetAttr(Node):
    """Rebind ``objexpr.name = valexpr``; the assigned value stays on the stack."""

    def __init__(self, name, objexpr, valexpr):
        Node.__init__(self)
        self.name = name
        self.objexpr = objexpr
        self.valexpr = valexpr

    def render(self, code):
        self.valexpr.render(code)
        code += [OP_dup]
        self.objexpr.render(code)
        code += [OP_set(self.name)]

class If(Node):
    """``if test: truesuite [else: elsesuite]``.

    ``testexpr`` must be a Const whose ``val`` is the test expression's AST
    node; that node is what gets rendered.
    """

    def __init__(self, testexpr, truesuite, elsesuite):
        Node.__init__(self)
        assert isinstance(testexpr, Const)
        self.testexpr = testexpr.val
        self.truesuite = truesuite
        self.elsesuite = elsesuite

    def render(self, code):
        # Layout: [test] [jump_if_false -> else] [pop, true body, jump -> end]
        #         [pop, else body]
        testcode = []
        self.testexpr.render(testcode)
        # TODO: consider rejecting assignments (OP_set ops) in the test.
        code.extend(testcode)
        truecode = [OP_pop]
        self.truesuite.render(truecode)
        elsecode = [OP_pop]
        if self.elsesuite:
            self.elsesuite.render(elsecode)
        # Always skip the else block, even when it is just the pop. Without
        # this jump the true path fell into the else OP_pop and discarded an
        # unrelated stack value.
        truecode.append(OP_jump(len(elsecode)))
        code.append(OP_jump_if_false(len(truecode)))
        code.extend(truecode)
        code.extend(elsecode)

class While(Node):
    """``while test: suite``.

    ``testexpr`` must be a Const whose ``val`` is the test expression's AST
    node; that node is rendered and re-evaluated before every iteration.
    """
    def __init__(self, testexpr, suite):
        Node.__init__(self)
        assert isinstance(testexpr, Const)
        self.testexpr = testexpr.val
        self.suite = suite

    def render(self, code):
        # Layout: [test] [jump_if_false -> exit] [pop, body, jump -> test] [pop]
        exprcode = []
        self.testexpr.render(exprcode)
        # OP_jump_if_false leaves the test value on the stack; both the loop
        # body and the exit path start by popping it.
        loopcode = [OP_pop]
        self.suite.render(loopcode)
        # Backward jump to the first test op, computed before the jump itself
        # is appended. Distance = len(exprcode) + 1 (the jump_if_false) +
        # len(loopcode so far), plus 1 because Frame.step adds 1 to pos after
        # every op.
        loopcode.append(OP_jump(-(len(loopcode) + len(exprcode) + 2)))
        code.extend(exprcode)
        # On a false test, skip the whole loop and land on the trailing pop.
        code.append(OP_jump_if_false(len(loopcode)))
        code.extend(loopcode + [OP_pop])

#########################################

def print_func(frame, context, arg):
    """Builtin ``print``: write ``arg`` to stdout; the result is None."""
    print(arg)

def type_func(frame, context, arg):
    """Builtin ``type``: the object-model type stored on ``arg``."""
    klass = arg.type
    return klass

def _declaration_signature(ast):
    """Return (name, argnames) for a declaration head like ``f(a, b as T)``.

    An ``as`` around the whole head, and around individual arguments, is
    stripped; only bare names are kept.

    Raises Exception if the head is not a Call node. An explicit raise is
    used instead of ``assert``: it survives ``python -O`` and carries a
    message.
    """
    if isinstance(ast, As):
        ast = ast.lhs
    if not isinstance(ast, Call):
        raise Exception("Declaration head must be a call, got %r" % (ast,))
    name = ast.child.name
    argnames = []
    for arg in ast.args:
        if isinstance(arg, As):
            argnames.append(arg.lhs.name)
        else:
            argnames.append(arg.name)
    return name, argnames


def _render_suite(suite, *tail):
    """Render ``suite`` followed by each node in ``tail`` into a new op list.

    Unlike appending ``tail`` to the suite, this leaves the AST untouched, so
    re-executing a declaration does not keep growing its body.
    """
    code = []
    suite.render(code)
    for node in tail:
        node.render(code)
    return code


def define_def(frame, context, ast, suite):
    """Definition handler for ``def``: bind a Function in the declaring scope.

    ``context`` is the scope BuiltinFunction.call creates for this call; its
    parent is the scope the declaration runs in.
    """
    name, args = _declaration_signature(ast)
    # Falling off the end of the body returns ApNone.
    code = _render_suite(suite, Return(Const(ApNone)))
    func = Function(args, code, context.parent)
    context.parent.define(name, func)


def define_gen(frame, context, ast, suite):
    """Definition handler for ``gen``: bind a Generator in the declaring scope."""
    name, args = _declaration_signature(ast)
    # After the body, yield ApNone forever. While expects a Const wrapping the
    # test *node*, hence Const(Const(True)); a bare Const(True) made
    # While.render call True.render() and crash.
    forever = While(Const(Const(True)),
                    Suite([ExprStmt(Yield(Const(ApNone)))]))
    code = _render_suite(suite, forever)
    func = Generator(args, code, context.parent)
    context.parent.define(name, func)


def define_class(frame, context, ast, suite):
    """Definition handler for ``class``: run the body, bind a Class.

    The body runs in a fresh scope whose bindings become the class dict.
    Base names in the head are resolved in the declaring scope; they used to
    be parsed and then silently ignored.
    """
    name, basenames = _declaration_signature(ast)
    bases = [context.parent.lookup(basename) for basename in basenames]
    code = []
    suite.render(code)
    callcontext = Context(context.parent)
    # Placeholder exits (previously a bare ``{}``): Frame.call, invoked by
    # call_object, replaces them before the frame runs.
    tmpframe = Frame(code, callcontext, Exits(None, None, None))
    call_object(tmpframe, (), context.parent)
    klass = Class(name, Metaclass, bases or None, callcontext.dict)
    context.parent.define(name, klass)

def make_operator(op):
    """Adapt a two-argument callable ``op`` to the builtin calling convention
    ``(frame, context, lhs, rhs)``."""
    def func(frame, context, lhs, rhs):
        result = op(lhs, rhs)
        return result
    return func

def null_binary(frame, context, lhs, rhs):
    """Operator builtin with no runtime effect (registered for ``as``/``in``):
    returns ``lhs`` and ignores ``rhs``."""
    result = lhs
    return result

#########################################

interpreter = None


def setup_interpreter():
    """Register the builtin declaration handlers and operators on the global
    ``interpreter``."""
    definitions = interpreter.definition.dict
    for keyword, handler in (("def", define_def),
                             ("gen", define_gen),
                             ("class", define_class)):
        definitions[keyword] = BuiltinFunction(("ast", "suite"), handler)

    operators = interpreter.operator.dict
    for opname, impl in (("add", make_operator(operator.add)),
                         ("subtract", make_operator(operator.sub)),
                         ("multiply", make_operator(operator.mul)),
                         ("divide", make_operator(operator.div)),
                         ("less_than", make_operator(operator.lt)),
                         ("greater_than", make_operator(operator.gt)),
                         ("equal_to", make_operator(operator.eq)),
                         ("not_equal_to", make_operator(operator.ne)),
                         ("as", null_binary),
                         ("in", null_binary)):
        operators[opname] = BuiltinFunction(("lhs", "rhs"), impl)

def initial_context():
    """Build a root scope (no parent) containing the ``print`` and ``type`` builtins."""
    root = Context(None)
    builtins = (("print", print_func), ("type", type_func))
    for name, impl in builtins:
        root.dict[name] = BuiltinFunction(("arg",), impl)
    return root

# Create the module-level interpreter singleton (Frame.__init__ reads this
# global), then register the builtin declaration handlers and operators on it.
interpreter = Interpreter()
setup_interpreter()
