(*
 *)

(* Types of the object language.
 * A Tvar carries a mutable cell: NONE while the variable is unbound,
 * SOME t once it has been bound to the type t. *)
datatype MLtype
  = Unit
  | Int
  | Char
  | Bool
  | Product of MLtype list        (* tuple type, one entry per component *)
  | Arrow of MLtype * MLtype      (* function type: domain, range *)
  | Tvar of MLtype option ref;    (* type variable *)


(* Sample types for experimenting with the functions in this file. *)
val t1 = Arrow (Int, Bool);
val t2 = Product [t1, Char, Unit];
val t3 = Arrow (Int, Arrow (Int, Bool));
val t4 = Product [t1, t2, t3, Bool];

(* prune t: follow a chain of bound type variables starting at t and
 * return the type at the end of the chain.  Every variable visited on
 * the way is re-pointed directly at that result, so later calls are
 * shorter.  A type that is not a bound Tvar is returned unchanged;
 * components nested inside Arrow/Product are not visited. *)
fun prune t =
  case t of
      Tvar (cell as ref (SOME inner)) =>
        let val root = prune inner
        in cell := SOME root;
           root
        end
    | _ => t;


(* showt t: render a type as a string.
 * Products are printed as their components joined by "*" inside
 * parentheses, arrows as "(dom -> rng)".  A bound type variable prints
 * as the type it is bound to; every unbound variable prints as "`a". *)
fun showt ty =
  case ty of
      Unit => "()"
    | Int => "int"
    | Char => "char"
    | Bool => "bool"
    | Product ts => "(" ^ String.concatWith "*" (List.map showt ts) ^ ")"
    | Arrow (dom, rng) => "(" ^ showt dom ^ " -> " ^ showt rng ^ ")"
    | Tvar (ref (SOME bound)) => showt bound
    | Tvar (ref NONE) => "`a";

(*
showt t1

showt t2

showt t3
*)

(* typeeq (x,y): decide whether two types are equal, binding unbound
 * type variables as needed to make them so (unification).
 *
 * Returns true when the types match (possibly after binding variables)
 * and false otherwise.  Bindings made before a mismatch is discovered
 * are not undone.
 *
 * listeq xs ys applies typeeq pairwise and requires equal lengths.
 *
 * Two guards keep the variable cells acyclic:
 *   - comparing an unbound variable with itself succeeds without
 *     binding it (binding it to itself would make prune loop forever);
 *   - a variable is never bound to a type that contains it (the occurs
 *     check); such a comparison yields false instead.
 *)
local
  (* occurs r t: does the variable cell r appear anywhere inside t? *)
  fun occurs r t =
    case prune t of
        Tvar r' => r = r'
      | Arrow(d,rg) => occurs r d orelse occurs r rg
      | Product ts => List.exists (occurs r) ts
      | _ => false

  (* bind r t: bind the unbound cell r to the already-pruned type t,
   * unless t is that same variable or contains it. *)
  fun bind r t =
    case t of
        Tvar r' => if r = r' then true else (r := SOME t; true)
      | _ => if occurs r t then false else (r := SOME t; true)
in
fun typeeq (x,y) =
case (prune x,prune y) of
  (Unit,Unit) => true
| (Int,Int) => true
| (Char,Char) => true
| (Bool,Bool) => true
| (Arrow(d1,r1),Arrow(d2,r2)) => typeeq(d1,d2)  andalso
                                 typeeq(r1,r2)
| (Product(ss),Product(ts)) =>  (listeq ss ts)
| (Tvar(r as (ref NONE)),t) => bind r t
| (t,Tvar(r as (ref NONE))) => bind r t
| (_,_) => false

and listeq (x::xs) (y::ys) =
          typeeq(x,y) andalso listeq xs ys
    | listeq [] [] = true
    | listeq _ _ = false
end;

(*
typeeq(t1,t2);

typeeq(t1,t1);

typeeq(Bool,Bool)
*)


(* Binary operators that can appear in an Infix expression. *)
datatype Op
  = Plus
  | Less
  | And

(* Literal constants, each tagged with the kind of value it carries. *)
datatype Constant =
    Cint of int
  | Cchar of char
  | Cbool of bool

(* Abstract syntax of expressions.  The comment beside each constructor
 * shows the concrete syntax it stands for. *)
datatype Exp =
    Lit of Constant                    (* 5                      *)
  | Var of string                      (* x                      *)
  | App of Exp * Exp                   (* f x                    *)
  | Tuple of Exp list                  (* (x,3,true)             *)
  | Infix of Exp * Op * Exp            (* x+3                    *)
  | Stmt of Exp list                   (* (print x; 3)           *)
  | If of Exp * Exp * Exp              (* if x then y else 3     *)
  | While of Exp * Exp                 (* while x do (f x)       *)
  | Anonfun of string * MLtype * Exp   (* (fn x => x+1): parameter name,
                                          declared parameter type, body *)
  | Let of Dec * Exp                   (* let val x = 1 in x end *)

(* Declarations that can appear in a Let. *)
and Dec =
    Valdec of string * Exp                      (* name, bound expression *)
  | Fundec of string * string * MLtype * Exp    (* function name, argument
                                                   name, argument type, body *)


(* Sample expressions, one for each Exp constructor. *)
val exp1  = Lit (Cint 5);
val exp2  = Var "x";
val exp3  = App (Var "f", Var "x");
val exp4  = Tuple [Var "x", Lit (Cint 3), Lit (Cbool true)];
val exp5  = Infix (Var "x", Plus, Lit (Cint 3));
val exp6  = Stmt [App (Var "print", Var "x"), Lit (Cint 3)];
val exp7  = If (Var "x", Var "y", Lit (Cint 3));
val exp8  = While (Var "x", App (Var "f", Var "x"));
val exp9  = Anonfun ("x", Int, Infix (Var "x", Plus, Lit (Cint 1)));
val exp10 = Let (Valdec ("x", Lit (Cint 1)), Var "x");


(* Raised when type checking fails; carries the offending expression
 * and a message describing the problem. *)
exception TypeError of Exp * string;

(* error e s: abort type checking by raising TypeError for expression e
 * with message s.  Never returns. *)
fun error e s = raise TypeError (e, s);

(* TCOp oper: the typing of a binary operator, as the triple
 * (left operand type, right operand type, result type). *)
fun TCOp oper =
  case oper of
      Plus => (Int, Int, Int)
    | Less => (Int, Int, Bool)
    | And => (Bool, Bool, Bool);

(* TCConstant c: the type of a literal constant; the carried value is
 * irrelevant, only the constructor matters. *)
fun TCConstant c =
  case c of
      Cint _ => Int
    | Cchar _ => Char
    | Cbool _ => Bool;


(* unexpected r t1 t2: raise a TypeError for expression r reporting that
 * type t1 was found where type t2 was required.  Never returns. *)
fun unexpected r t1 t2 =
  let val msg =
        String.concat ["Found type ", showt t1, " expecting type ", showt t2]
  in error r msg end;

(* TCExp x cntxt: compute the type of expression x.
 *
 *   x     - the expression to check
 *   cntxt - list of (name, type) pairs; lookup takes the first match,
 *           and new bindings are added at the front, so inner bindings
 *           shadow outer ones
 *
 * Returns the MLtype of x.  Raises TypeError (through error/unexpected)
 * when a sub-expression has the wrong type, a variable is undeclared,
 * or a statement sequence is empty.  Comparisons go through typeeq,
 * which may bind type variables, including ones stored in cntxt.
 *
 * TCDec d cntxt: check a declaration and return the pair
 * (type of the declared name, cntxt extended with that name).
 *)
fun TCExp x cntxt =
  case x of
    Lit c => TCConstant c
  | Var s =>
      (case List.find (fn (nm,_) => nm=s) cntxt of
           SOME(_,t) => t
         | NONE => error x "Undeclared variable")
  | Infix(l,oper,r) =>
      let val ltype = TCExp l cntxt
          val rtype = TCExp r cntxt
          val (lneed,rneed,result) = TCOp oper
      (* Both operands are always compared (the pair is built before the
         match), and a bad left operand is reported in preference to a
         bad right one. *)
      in case (typeeq(ltype,lneed),typeeq(rtype,rneed)) of
           (true,true) => result
         | (true,false) => unexpected r rtype rneed
         | (false,_) => unexpected l ltype lneed
      end
  | App(f,arg) =>
      let val ftype = TCExp f cntxt
          val argtype = TCExp arg cntxt
      in case prune ftype of
           Arrow(dom,rng) =>
             if typeeq(dom,argtype)
                then rng
                else unexpected arg argtype dom
         | other => error f ("the type "^showt other^" is not a function")
      end
  | Stmt es =>
      (* Every statement is checked; the sequence has the type of its
         final statement. *)
      let val estypes = List.map (fn e => TCExp e cntxt) es
          fun last [t] = t
            | last (_::ts) = last ts
            | last [] = error x "Statement sequence with no elements"
      in last estypes end
  | Tuple es =>
      let val estypes = List.map (fn e => TCExp e cntxt) es
      in Product estypes end
  | If(test,thn,els) =>
      (* All three parts are checked before any comparison is made. *)
      let val testtype = TCExp test cntxt
          val thntype = TCExp thn cntxt
          val elstype = TCExp els cntxt
      in if typeeq(testtype,Bool)
            then if typeeq(thntype,elstype)
                    then thntype
                    else unexpected thn thntype elstype
            else error test ("the type "^
                             showt testtype^
                             " is not boolean")
      end
  | While(test,body) =>
      (* The loop is given the type of its body. *)
      let val ttype = TCExp test cntxt
          val btype = TCExp body cntxt
      in if typeeq(ttype,Bool)
            then btype
            else unexpected test ttype Bool
      end

(* Earlier version, which trusted the declared parameter type:
   | Anonfun(x,t,body) =>
      let val btype = TCExp body ((x,t)::cntxt)
      in Arrow(t,btype) end *)

  | Anonfun(param, _ ,body) =>
      (* The declared parameter type is ignored.  The parameter starts as
         a fresh unbound type variable, and whatever checking the body
         binds it to becomes the domain of the function type. *)
      let val t = Tvar(ref NONE)
          val btype =
              TCExp body ((param,t)::cntxt)
      in Arrow(prune t,btype) end

  | Let(d,b) =>
      let val (_,cntxt2) = TCDec d cntxt
          val btype = TCExp b cntxt2
      in btype end

and TCDec (Valdec(nm,exp)) cntxt =
      let val nmtype = TCExp exp cntxt
      in (nmtype,(nm,nmtype)::cntxt) end
  | TCDec (Fundec(nm,arg,argtype,body)) cntxt =
      (* The body is checked with only the argument added to cntxt; the
         function's own name is not in scope there, so a use of nm inside
         the body resolves to an outer binding or is reported as
         undeclared. *)
      let val bodytype = TCExp body ((arg,argtype)::cntxt)
          val nmtype = Arrow(argtype,bodytype)
          val newcntxt = (nm,nmtype)::cntxt
      in (nmtype,newcntxt)
      end



(* 
*)