(*
  *)

(* The code for illustrating Mutual Recursion *)

(* MLtype: the types of the object language being checked.
   showt renders each constructor as text; typeeq compares two of them. *)
datatype MLtype =
    Unit                          (* shown as "()"   *)
  | Int                           (* shown as "int"  *)
  | Real                          (* shown as "real" *)
  | Char                          (* shown as "char" *)
  | Bool                          (* shown as "bool" *)
  | Product of MLtype list        (* tuple type: one component type per element *)
  | Arrow of (MLtype * MLtype)    (* function type: (domain, range) *)

(* showt ty: render an MLtype as a string.
   Base types print as their ML names and Unit as "()".
   A Product prints its components separated by "*" inside parentheses
   (an empty Product therefore prints as "()").
   An Arrow prints as "(domain -> range)". *)
fun showt ty =
  case ty of
    Unit => "()"
  | Int => "int"
  | Real => "real"
  | Char => "char"
  | Bool => "bool"
  | Product components =>
      "(" ^ String.concatWith "*" (map showt components) ^ ")"
  | Arrow (dom, rng) =>
      "(" ^ showt dom ^ " -> " ^ showt rng ^ ")"

(* typeeq (x,y): structural equality of two MLtypes.
   Base types are equal only to themselves, Arrows when both domain and
   range agree, Products when listeq accepts their component lists.
   Every other pairing is unequal. *)
fun typeeq (Unit, Unit) = true
  | typeeq (Int, Int) = true
  | typeeq (Real, Real) = true
  | typeeq (Char, Char) = true
  | typeeq (Bool, Bool) = true
  | typeeq (Arrow (d1, r1), Arrow (d2, r2)) =
      typeeq (d1, d2) andalso typeeq (r1, r2)
  | typeeq (Product ss, Product ts) = listeq ss ts
  | typeeq _ = false

(* listeq xs ys: true when both lists have the same length and their
   elements are pairwise typeeq. *)
and listeq [] [] = true
  | listeq (s :: ss) (t :: ts) = typeeq (s, t) andalso listeq ss ts
  | listeq _ _ = false;

(* **************************************************** *)

(* Op: the binary operators allowed in an Infix expression.
   TCOp gives the operand and result types of each one. *)
datatype Op = Plus | Less | And

(* Constant: literal values of the object language.
   TCConstant maps each constructor to its MLtype. *)
datatype Constant
  = Cint of int      (* integer literal, typed Int    *)
  | Cchar of char    (* character literal, typed Char *)
  | Cbool of bool    (* boolean literal, typed Bool   *)

(* Exp and Dec are mutually recursive: a Let expression contains a
   declaration, and every declaration contains expressions. *)
datatype Exp
  = Lit of Constant          (* 5                  *)
  | Var of string            (* x                  *)
  | App of Exp*Exp           (* f x                *)
  | Tuple of Exp list        (* (x,3,true)         *)
  | Infix of Exp*Op*Exp      (* x+3                *)
  | Stmt of Exp list         (* (print x; 3)       *)
  | If of Exp * Exp * Exp    (* if x then y else 3 *)
  | While of Exp * Exp       (* while x do (f x)   *)
  | Anonfun of string* MLtype * Exp  (* (fn x => x+1)      *)
  | Let of Dec*Exp           (* let val x = 1 in x end *)
  | Int2Real of Exp          (* coercion wrapper; Fix places it around an operand of a mixed Int/Real Plus *)

and Dec
  = Valdec of (string*MLtype)*Exp                  (* (name, declared type), bound expression *)
  | Fundec of (string*MLtype*MLtype)*string*Exp    (* (function name, domain, range), parameter name, body *)
  | Mutdec of Dec list                             (* declarations that TCDec checks with each other's names in scope *)

(* TCConstant c: the MLtype of a literal constant.
   The payload is irrelevant; only the constructor decides the type. *)
fun TCConstant c =
  case c of
    Cint _ => Int
  | Cchar _ => Char
  | Cbool _ => Bool;


(* TypeError (e, msg): raised (via error) when type checking fails;
   carries the offending expression and a description of the problem. *)
exception TypeError of Exp*string;
(* error e s: abort type checking by raising TypeError for expression e
   with message s.  Never returns, so its result fits any type. *)
fun error e s = raise TypeError (e, s);
(* unexpected r t1 t2: raise TypeError for expression r, reporting that
   type t1 was found where type t2 was expected. *)
fun unexpected r t1 t2 =
  let val message =
        String.concat ["Found type ", showt t1, " expecting type ", showt t2]
  in error r message end;

(* TCExp x cntxt: compute the MLtype of expression x.
   cntxt is an association list of (name, type) pairs searched front to
   back, so earlier entries shadow later ones.
   Only Lit and Var are handled here; any other expression form falls
   through the case and raises Match.
   Raises TypeError when a variable has no binding in cntxt. *)
fun TCExp x cntxt =
  case x of
    Lit c => TCConstant c
  | Var s =>
     (case List.find (fn (nm,_) => nm=s) cntxt of
          SOME(_,t) => t
        | NONE => error x "Undeclared variable")

(* TCDec d cntxt: type check declaration d under cntxt and return ONLY
   the bindings that d introduces (callers that need the full context
   must prepend the result to cntxt themselves).
   Raises TypeError when a body's type disagrees with its declared type. *)
and TCDec (Valdec((nm,t),exp)) cntxt =
      let val bodyt = TCExp exp cntxt
      in if typeeq(t,bodyt)
            then [(nm,t)]
            else unexpected exp bodyt t
      end
  | TCDec (Fundec((f,dom,rng),x,body)) cntxt =
      let val ft = Arrow(dom,rng)
          (* the parameter is visible only while checking the body *)
          val bodyt = TCExp body ((x,dom)::cntxt)
          (* f is not recursive unless inside a Mutdec *)
      in if typeeq(bodyt,rng)
            then [(f,ft)]
            else unexpected body bodyt rng
      end
  | TCDec (Mutdec ds) cntxt =
      let (* pass1: pre-bind every declared name at its declared type so
             the bodies checked in pass 2 may refer to one another.
             Function parameters are deliberately NOT added here: the
             Fundec clause binds each parameter for its own body only,
             so it cannot leak into sibling declarations.
             A nested Mutdec contributes the names of its members. *)
          fun pass1 [] acc = acc
            | pass1 (Valdec(p,_)::rest) acc = pass1 rest (p::acc)
            | pass1 (Fundec((f,d,r),_,_)::rest) acc =
                    pass1 rest ((f,Arrow(d,r))::acc)
            | pass1 (Mutdec inner::rest) acc = pass1 rest (pass1 inner acc)
          val temp = pass1 ds cntxt
          (* pass 2: check each declaration against the pre-bound context *)
          val pass2 = map (fn d => TCDec d temp) ds
      in List.concat pass2 end;


(* TCOp oper: the typing of a binary operator as a triple
   (left operand type, right operand type, result type). *)
fun TCOp oper =
  case oper of
    Plus => (Int, Int, Int)
  | Less => (Int, Int, Bool)
  | And => (Bool, Bool, Bool);

(* TCConstant c: the MLtype of a literal constant.
   NOTE(review): this repeats, clause for clause, the TCConstant defined
   earlier in this file; the redefinition only shadows it. *)
fun TCConstant (Cint _) = Int
  | TCConstant (Cchar _) = Char
  | TCConstant (Cbool _) = Bool;


(* ********* Code rebuilding type checker ******** *)

(* fst: first component of a pair *)
fun fst (first, _) = first
(* snd: second component of a pair *)
fun snd (_, second) = second

(* Fix result info left right: rebuild an Infix node from already-checked
   operands.  info is (operator, left operand type, right operand type).
   For Plus with one Int and one Real operand, the Int side is wrapped in
   Int2Real; every other combination is rebuilt unchanged.
   result is passed through untouched as the first component. *)
fun Fix result (Plus, Int, Real) left right =
      (result, Infix (Int2Real left, Plus, right))
  | Fix result (Plus, Real, Int) left right =
      (result, Infix (left, Plus, Int2Real right))
  | Fix result (oper, _, _) left right =
      (result, Infix (left, oper, right))

(* TCExp2 x cntxt: type check expression x and rebuild it.
   cntxt is an association list of (name, type) pairs searched front to
   back.  Returns (type of x, rebuilt expression); the rebuilt expression
   has the same shape as x except where Fix inserts an Int2Real coercion.
   Raises TypeError (via error / unexpected) on any type mismatch or
   undeclared variable.  An Int2Real node in the input is not handled and
   raises Match. *)
fun TCExp2 x cntxt =
  case x of
    Lit c => (TCConstant c,x)
  | Var s =>
     (case List.find (fn (nm,_) => nm=s) cntxt of
          SOME(_,t) => (t,Var s)
        | NONE => error x "Undeclared variable")
  | Infix(l,oper,r) =>
     let val (ltype,l2) = TCExp2 l cntxt
         val (rtype,r2) = TCExp2 r cntxt
         val (lneed,rneed,result) = TCOp oper
     in case (typeeq(ltype,lneed),typeeq(rtype,rneed)) of
          (true,true) => Fix result (oper,ltype,rtype) l2 r2
        | (true,false) => unexpected r rtype rneed
          (* when the left operand is wrong it is reported first,
             whether or not the right one is also wrong *)
        | (false,_) => unexpected l ltype lneed
     end
  | App(f,arg) =>
     let val (ftype,f2) = TCExp2 f cntxt
         val (argtype,arg2) = TCExp2 arg cntxt
     in case ftype of
          Arrow(dom,rng) =>
            if typeeq(dom,argtype)
               then (rng,App(f2,arg2))
               else unexpected arg argtype dom
        | other => error f ("the type "^showt other^" is not a function")
     end
  | Stmt xs =>
     let val pairs = List.map (fn e => TCExp2 e cntxt) xs
         val xs2 = List.map snd pairs
         (* the type of a sequence is the type of its final expression *)
         fun last [p] = p
           | last (_::ps) = last ps
           | last [] = error x "Statement sequence with no elements"
     in (fst(last pairs), Stmt xs2) end
  | Tuple xs =>
     let val pairs = List.map (fn e => TCExp2 e cntxt) xs
         val xstypes = map fst pairs
         val xs2 = map snd pairs
     in (Product xstypes, Tuple xs2) end
  | If(test,y,z) =>
     let val (testtype,test2) = TCExp2 test cntxt
         val (ytype,y2) = TCExp2 y cntxt
         val (ztype,z2) = TCExp2 z cntxt
     in if typeeq(testtype,Bool)
           then if typeeq(ytype,ztype)
                   then (ytype,If(test2,y2,z2))
                   else unexpected y ytype ztype
           else error test ("the type "^
                            showt testtype^
                            " is not boolean")
     end
  | While(test,body) =>
     let val (ttype,test2) = TCExp2 test cntxt
         val (btype,body2) = TCExp2 body cntxt
     in if typeeq(ttype,Bool)
           then (btype,While(test2,body2))
           else unexpected test ttype Bool
     end
  | Anonfun(param,t,body) =>
     let val (btype,body2) = TCExp2 body ((param,t)::cntxt)
     in (Arrow(t,btype),Anonfun(param,t,body2)) end
  | Let(d,b) =>
     let (* TCDec returns only the bindings introduced by d, so they are
            prepended to the enclosing context: the body sees the new
            names first and still sees everything already in scope *)
         val newbindings = TCDec d cntxt
         val (btype,b2) = TCExp2 b (newbindings @ cntxt)
     in (btype,Let(d,b2)) end;

(* ********** Scopes ******************* *)


(* A table is an association list of (name, value) pairs.
   A Scope is a list of tables; initialize conses a fresh empty table onto
   the front, so the head is the most recently opened table. *)
type ('name,'value)table = ('name * 'value) list;
type ('name,'value)Scope = (('name,'value)table) list ;


(* initialize scope: open a new, empty table in front of scope *)
fun initialize scope = nil :: scope
(* insert name value scope: add the binding (name, value) at the front of
   the first table of scope, leaving the remaining tables untouched.
   An empty scope matches no clause and raises Match. *)
fun insert name value (top :: rest) =
  let val extended = (name, value) :: top
  in extended :: rest end
(* Lookup name scope: return the value bound to name, searching the
   tables of scope from the first (most recently opened) outward, and
   each table front to back, so the nearest binding wins.
   Raises Fail when no table binds name (including an empty scope).
   The not-found case used to return the function `error` itself rather
   than raising, which also forced every stored value to have error's
   type; it now raises directly. *)
fun Lookup name [] = raise Fail "Lookup: name not bound in any scope"
  | Lookup name (table::scope) =
      (case List.find (fn (x,_) => x=name) table of
           SOME (_,y) => y
           (* not in this table: continue with the enclosing scopes *)
         | NONE => Lookup name scope)
(* finalize scope: close the first table, returning the remaining tables.
   An empty scope matches no clause and raises Match. *)
fun finalize (_ :: enclosing) = enclosing



(* 
*)