package code.loop.parser;

import code.lang.SystemModule;
import code.lang.TupleModule;
import code.loader.LoadManager;
import code.symbols.ConstructorSymbol;
import code.symbols.DataSymbol;
import code.symbols.TypeSymbol;
import code.table.AmbiguousSymbolException;
import code.table.MapTable;
import code.table.ModuleTable;
import code.table.UndefinedSymbol;
import code.term.Term;

import java.util.Vector;
import java.util.regex.Pattern;

/**
 * @author jimeng
 * @since August 2, 2003
 */
public class AstTermCreateVisitor implements AstVisitor {

    private static final String PreludeModuleName = "Prelude";
    private static final String SupportModuleName = "$Support";
    private static final String leftSectionName = "leftSection";
    private static final String rightSectionName = "rightSection";

    private static final int leftSectionId
	= ModuleTable.getId(SupportModuleName, leftSectionName);
    private static final int rightSectionId
	= ModuleTable.getId(SupportModuleName, rightSectionName);
    private static final int closureId
	= ModuleTable.getId(SystemModule.moduleName, SystemModule.closureName);
    private static final int applyId
	= ModuleTable.getId(SystemModule.moduleName, SystemModule.applyName);
    private static final int negateId
	= ModuleTable.getId(PreludeModuleName, "-");
    private static final int nilId
	= ModuleTable.getId(PreludeModuleName, "[]");
    private static final int consId = ModuleTable.getId(PreludeModuleName, ":");
    private static final int enumFromId
	= ModuleTable.getId(PreludeModuleName, "enumFrom");
    private static final int enumFromThenId
	= ModuleTable.getId(PreludeModuleName, "enumFromThen");
    private static final int enumFromToId
	= ModuleTable.getId(PreludeModuleName, "enumFromTo") ;
    private static final int enumFromThenToId
	= ModuleTable.getId(PreludeModuleName, "enumFromThenTo");

    private static final int ifthenelseId
	= ModuleTable.getId(PreludeModuleName, "if_then_else");

    private static final int unitId
	= ModuleTable.getId(TupleModule.moduleName, "()");
    private static final int pairId
	= ModuleTable.getId(TupleModule.moduleName, "(,)");

    private int getRoot(String maybeQualifiedName) {
	// TODO: should check that the symbol is public,
	// otherwise throw HiddenSymbol exception
	int root;
	String moduleName;
	String symbolName;
	boolean qualified = maybeQualifiedName.matches("\\w+\\..+");
	if (qualified) {
	    // qualified identifier
	    int index = maybeQualifiedName.indexOf('.');
	    moduleName = maybeQualifiedName.substring(0,index);
	    symbolName = maybeQualifiedName.substring(index+1);
	    root = ModuleTable.sureGetId(moduleName, symbolName);
	} else {
	    // unqualified identifier
	    // first look in the current module
	    // if found, that's it
	    symbolName = maybeQualifiedName;
	    try {
		moduleName = LoadManager.getCurrentModuleName();
		root = ModuleTable.sureGetId(moduleName, symbolName);
	    } catch (UndefinedSymbol ex0) {
		// if not found, look in all the imported modules
		// and ensure there is only one symbol with the given name
		try { 
		    root = ModuleTable.sureGetId("Prelude", symbolName);
		} catch (UndefinedSymbol ex1) {
		    // no symbol with the given name was found
		    // re-throw exception with the original name
		    throw new UndefinedSymbol(maybeQualifiedName);
		}
	    }
	}
	return root;
    }

    public Object visit(Ast.Term term, Object o) {
        return term.expr.accept(this, term.wt.variableList);
    }

    public Object visit(Ast.IfExpr ifExpr, Object o) {
        Term[] tmp = new Term[]{(Term) ifExpr.e1.accept(this, o), (Term) ifExpr.e2.accept(this, o), (Term) ifExpr.e3.accept(this, o)};
        Term term = new Term(ifthenelseId, tmp);
        return term;
    }

    public Object visit(Ast.InfixOperation infixOperation, Object o)
            throws AmbiguousSymbolException {
	// re-routed to FunctionExpr to handle >>=
        // which is infix, takes 3 arguments, and is
	// typically invoked with 2 arguments only
	Vector basics = new Vector(3);
	basics.add(new Ast.PrefixId(infixOperation.infixId));
	basics.add(infixOperation.e1);
	basics.add(infixOperation.e2);
	Ast.FunctionExpr functionExpr = new Ast.FunctionExpr(basics);
	return visit(functionExpr, o);
	/*  simpler code without re-routing
	int root = getRoot(infixOperation.infixId);
        DataSymbol infixSymbol = MapTable.getSymbol(root);
        return new Term(root, new Term[]{(Term) infixOperation.e1.accept(this, o), (Term) infixOperation.e2.accept(this, o)});
	*/
    }

    public Object visit(Ast.FunctionExpr functionExpr, Object o) {
        int argsNumber = functionExpr.basics.size() - 1;
        Term term = (Term) ((Ast.Expr) functionExpr.basics.elementAt(0)).accept(this, o);
        /**
         * it is a function application or partial application,
         * otherwise constructor, variable, function without argument
         * or partial function without any application now
         */
        if (argsNumber > 0) {
            /**
             * check if it is already fully applied, then strip the closure off
             * to raise up runtime effeciency only if it's a single op that wrapped
             * with closure
             */
            int innerRoot = -1;
            int arity = -1;
            //flag to verify if basics[0] is a single op or a combo term.
            boolean simpleOp = false;
            if (term.getRoot() == closureId) {
                Term innerTerm = term.getArgument((byte) 0);
                innerRoot = innerTerm.getRoot();
                simpleOp = true;
                arity = (MapTable.getSymbol(innerRoot)).arity;
            }
            Term[] argument = new Term[argsNumber];
            for (int i = 0; i < argsNumber; i++)
                argument[i] = (Term) ((Ast.Expr) functionExpr.basics.elementAt(i + 1)).accept(this, o);
            /**
             * check if it is already fully applied, then strip the closure off
             * to raise up runtime effeciency only if it's a single op that wrapped
             * with closure
             */
            if (simpleOp && (arity == argsNumber))
                term = new Term(innerRoot, argument);
            else {
                for (int i = 0; i < argsNumber; i++)
                    term = new Term(applyId, new Term[]{term, argument[i]});
            }
        }
        return term;
    }

    public Object visit(Ast.Uminus uminus, Object o) {
        Term zeroTerm = new Term(0);
        Term[] tmp = new Term[]{zeroTerm, (Term) uminus.e.accept(this, o)};
        Term term = new Term(negateId, tmp);
        return term;
    }

    public Object visit(Ast.PrefixId prefixId, Object o)
            throws AmbiguousSymbolException {
        Vector variableList = (Vector) o;
        Term term = null;
        if (variableList.contains(prefixId.id)) {
            //it is a variable
            term = VarFactory.getVariable(prefixId.id);
        } else {
            //it is a constructor or function
	    int root = getRoot(prefixId.id);
            term = new Term(root, new Term[0]);
            int arity = (MapTable.getSymbol(root)).arity;
            //it is a constructor or function which takes argument, wrap it with closure
            //otherwise it has no argument
            if (arity > 0) {
                term = new Term(closureId, new Term[]{term});
            }
        }
        return term;
    }

    public Object visit(Ast.InfixId infixId, Object o)
            throws AmbiguousSymbolException {
        int root = ModuleTable.getId(infixId.id);
        Term term = new Term(root, new Term[0]);
        term = new Term(closureId, new Term[]{term});
        return term;
    }

    public Object visit(Ast.ParenExpr parenExpr, Object o) {
        return parenExpr.e.accept(this, o);
    }

    // TODO: LeftSection and RightSection are the same!! change them !!!!
    public Object visit(Ast.LeftSection leftSection, Object o)
            throws AmbiguousSymbolException {
        // apply(apply(closure(leftSection),exp),closure(op))
        Term[] tmp = new Term[0];
        Term leftSectionTerm = new Term(leftSectionId, tmp);

        tmp = new Term[]{leftSectionTerm};
        Term closureLeftSection = new Term(closureId, tmp);

        Term exprTerm = (Term) leftSection.e.accept(this, o);
        tmp = new Term[]{closureLeftSection, exprTerm};
        Term applyClosureLeftSectionExprTerm = new Term(applyId, tmp);

        int root = ModuleTable.getId(leftSection.infixId);
        tmp = new Term[0];
        Term infixTerm = new Term(root, tmp);

        tmp = new Term[]{infixTerm};
        Term closureInfixTerm = new Term(closureId, tmp);
        tmp = new Term[]{applyClosureLeftSectionExprTerm, closureInfixTerm};
        Term finalTerm = new Term(applyId, tmp);
        return finalTerm;
    }


    // TODO: LeftSection and RightSection are the same!! change them !!!!
    public Object visit(Ast.RightSection rightSection, Object o)
            throws AmbiguousSymbolException {
        // apply(apply(closure(rightSection),exp),closure(op))
        Term[] tmp = new Term[0];
        Term rightSectionTerm = new Term(rightSectionId, tmp);

        tmp = new Term[]{rightSectionTerm};
        Term closureRightSection = new Term(closureId, tmp);

        Term exprTerm = (Term) rightSection.e.accept(this, o);
        tmp = new Term[]{closureRightSection, exprTerm};
        Term applyClosureRightSectionExprTerm = new Term(applyId, tmp);

        int root = ModuleTable.getId(rightSection.infixId);
        tmp = new Term[0];
        Term infixTerm = new Term(root, tmp);

        tmp = new Term[]{infixTerm};
        Term closureInfixTerm = new Term(closureId, tmp);

        //check if the operation is overloaded
        DataSymbol infixSymbol = MapTable.getSymbol(infixTerm.getRoot());
        tmp = new Term[]{applyClosureRightSectionExprTerm, closureInfixTerm};
        Term finalTerm = new Term(applyId, tmp);
        return finalTerm;
    }

    public Object visit(Ast.IntLit intLit, Object o) {
        return new Term(intLit.i);
    }

    public Object visit(Ast.CharLit charLit, Object o) {
        return new Term(charLit.c);
    }

    public Object visit(Ast.StringLit stringLit, Object o) {
        char[] charArray = stringLit.s.toCharArray();
        Vector v = new Vector();
        for (int i = 0; i < charArray.length; ++i) {
            v.add(new Ast.CharLit(charArray[i]));
        }
        return visit(new Ast.List(v), o);
    }

    public Object visit(Ast.FloatLit floatLit, Object o) {
        return new Term(floatLit.f);
    }

    public Object visit(Ast.List list, Object o) {
        //construct nil term at the end
        Term[] tmp = new Term[0];
        Term term = new Term(nilId, tmp);

        //if nil
        if (list.listListing.size() == 0) {
            return term;
        }
        for (int i = list.listListing.size() - 1; i >= 0; i--) {
            tmp = new Term[]{(Term) ((Ast.Expr) list.listListing.elementAt(i)).accept(this, o), term};
            term = new Term(consId, tmp);
        }
        return term;
    }

    public Object visit(Ast.Tuple tuple, Object o) {
        Vector tupleListing = tuple.tupleListing;
        Term term = null;
        //if unit tuple
        if (tupleListing.size() == 0) {
            Term[] tmp = new Term[0];
            term = new Term(unitId, tmp);
        } else if (tupleListing.size() == 2) {//it is a tuple2
            Term[] tmp = new Term[]{(Term) ((Ast.Expr) tupleListing.elementAt(0)).accept(this, o), (Term) ((Ast.Expr) tupleListing.elementAt(1)).accept(this, o)};
            term = new Term(pairId, tmp);
        } else {//it is multi-dimentional tuple, create now
            int dimension = tupleListing.size();

            Term[] tmp = new Term[dimension];
            for (int i = 0; i < dimension; ++i) {
                tmp[i] = (Term) ((Ast.Expr) tupleListing.elementAt(i)).accept(this, o);
            }
            //if this dimensional tuple not created yet
            String tupleConstructorName
		= TupleModule.getConstructorName(dimension);

            if (ModuleTable.getSymbol(TupleModule.moduleName,
				      tupleConstructorName) == null) {
                // if we get null, then we need to store the symbol in
                // the table manually
                ConstructorSymbol ntuple
		    = TupleModule.makeTupleConstructor(dimension);
                ModuleTable.installSymbol(TupleModule.moduleName, ntuple);
		// Type symbols are NOT installed into the symbol table
            }
            int root = ModuleTable.getId(TupleModule.moduleName,
					 tupleConstructorName);
            term = new Term(root, tmp);
        }
        return term;
    }


    /**
     * create four functions enumFrom e, ...
     */
    public Object visit(Ast.ArithSeq arithSeq, Object o) {
        Term term = null;
        if ((arithSeq.then == null) && (arithSeq.to == null)) {
            Term[] tmp = new Term[]{(Term) (arithSeq.from).accept(this, o)};
            term = new Term(enumFromId, tmp);
        }
        if ((arithSeq.then != null) && (arithSeq.to == null)) {
            Term[] tmp = new Term[]{(Term) (arithSeq.from).accept(this, o),
				    (Term) (arithSeq.then).accept(this, o)};
            term = new Term(enumFromThenId, tmp);
        }
        if ((arithSeq.then == null) && (arithSeq.to != null)) {
            Term[] tmp = new Term[]{(Term) (arithSeq.from).accept(this, o),
				    (Term) (arithSeq.to).accept(this, o)};
            term = new Term(enumFromToId, tmp);
        }
        if ((arithSeq.then != null) && (arithSeq.to != null)) {
            Term[] tmp = new Term[]{(Term) (arithSeq.from).accept(this, o),
				    (Term) (arithSeq.then).accept(this, o),
				    (Term) (arithSeq.to).accept(this, o)};
            term = new Term(enumFromThenToId, tmp);
        }
        return term;
    }

}
