(* Type checking for the core calculus. *)

module Check : sig
  (** [check e] type-checks [e] in {!initial_env} and returns its type.
      @raise Type_error if [e] is ill-typed.
      @raise Row_error if a record literal or switch repeats a label.
      An unbound variable raises whatever [Env.lookup] raises. *)
  val check : Tcore.exp -> Tcore.typ

  (** [check_in_env env e] type-checks [e] under [env] and returns its type.
      Raises the same exceptions as {!check}. *)
  val check_in_env : Tcore.typ Env.env -> Tcore.exp -> Tcore.typ

  (** Raised when two types have no least upper bound (or, dually, no
      greatest lower bound). *)
  exception Incomparable

  (** Raised on an ill-typed expression; carries a message and the
      offending expression. *)
  exception Type_error of string * Tcore.exp

  (** Raised when a row (record fields or switch arms) repeats a label;
      carries a message and the offending row. *)
  exception Row_error of string * Tcore.row

  (** [join_type_list ts] is the least upper bound of the non-empty list
      [ts], folded left to right.
      @raise Failure if [ts] is empty.
      @raise Incomparable if some intermediate join does not exist. *)
  val join_type_list : Tcore.typ list -> Tcore.typ

  (** [subtype t1 t2] is [true] when [t1] is a subtype of [t2]. *)
  val subtype : Tcore.typ -> Tcore.typ -> bool

  (** Environment binding each name in [Tcore.primops] to its arrow type. *)
  val initial_env : Tcore.typ Env.env
end = struct
  open Tcore

  exception Incomparable

  (* [join_types t1 t2] is the least upper bound of [t1] and [t2];
     [meet_types t1 t2] is their greatest lower bound.  Arrows are
     contravariant in the argument, so a join of arrows meets the argument
     types.  Records join by keeping only the common fields and sums join
     by keeping every label (and dually for meets).  Both raise
     [Incomparable] when no bound exists.

     The row helpers below walk both rows in step and therefore assume the
     rows are sorted by label, as [type_row0] produces them.
     NOTE(review): type annotations written in the program (e.g. the
     parameter type of [Abs]) are not re-sorted here; if the parser can
     produce unsorted rows, joins/meets involving them may be wrong —
     confirm the parser normalises them. *)
  let rec join_types (t1 : typ) (t2 : typ) : typ =
    match (t1, t2) with
    | IntT, IntT -> IntT
    | ArrowT (t1a, t1r), ArrowT (t2a, t2r) ->
        ArrowT (meet_types t1a t2a, join_types t1r t2r)
    | RecordT r1, RecordT r2 -> RecordT (intersect_rows join_types r1 r2)
    | SumT r1, SumT r2 -> SumT (union_rows join_types r1 r2)
    | _ -> raise Incomparable

  and meet_types (t1 : typ) (t2 : typ) : typ =
    match (t1, t2) with
    | IntT, IntT -> IntT
    | ArrowT (t1a, t1r), ArrowT (t2a, t2r) ->
        ArrowT (join_types t1a t2a, meet_types t1r t2r)
    | RecordT r1, RecordT r2 -> RecordT (union_rows meet_types r1 r2)
    | SumT r1, SumT r2 -> SumT (intersect_rows meet_types r1 r2)
    | _ -> raise Incomparable

  (* Labels present in both sorted rows, with their types combined by
     [combine].  A label whose two types cannot be combined is dropped
     instead of failing the whole operation, so this never raises
     [Incomparable] itself. *)
  and intersect_rows (combine : typ -> typ -> typ) (r1 : rtyp) (r2 : rtyp) :
      rtyp =
    match (r1, r2) with
    | _, [] | [], _ -> []
    | (l1, t1) :: r1', (l2, t2) :: r2' ->
        if l1 = l2 then (
          match combine t1 t2 with
          | t -> (l1, t) :: intersect_rows combine r1' r2'
          | exception Incomparable -> intersect_rows combine r1' r2')
        else if l1 < l2 then intersect_rows combine r1' r2
        else intersect_rows combine r1 r2'

  (* Every label from either sorted row, still sorted.  A label present in
     both gets its two types combined by [combine]; if they cannot be
     combined, [Incomparable] propagates. *)
  and union_rows (combine : typ -> typ -> typ) (r1 : rtyp) (r2 : rtyp) : rtyp =
    match (r1, r2) with
    | _, [] -> r1
    | [], _ -> r2
    | (l1, t1) :: r1', (l2, t2) :: r2' ->
        if l1 = l2 then (l1, combine t1 t2) :: union_rows combine r1' r2'
        else if l1 < l2 then (l1, t1) :: union_rows combine r1' r2
        else (l2, t2) :: union_rows combine r1 r2'

  (* Least upper bound of a non-empty list of types, joined left to right. *)
  let join_type_list = function
    | [] -> failwith "join_type_list on []"
    | t :: ts -> List.fold_left join_types t ts

  (* [subtype t1 t2] checks t1 <: t2 structurally.  Arrows are contravariant
     in the argument and covariant in the result.  A record is a subtype
     when it has every field of the supertype at a subtype.  A sum is a
     subtype when each of its labels occurs in the supertype at a
     supertype.  Rows are looked up by label, so their order does not
     matter here. *)
  let rec subtype (t1 : typ) (t2 : typ) : bool =
    match (t1, t2) with
    | IntT, IntT -> true
    | ArrowT (t1a, t1r), ArrowT (t2a, t2r) -> subtype t2a t1a && subtype t1r t2r
    | RecordT r1, RecordT r2 ->
        List.for_all
          (fun (l, t) ->
            match List.assoc l r1 with
            | t' -> subtype t' t
            | exception Not_found -> false)
          r2
    | SumT r1, SumT r2 ->
        List.for_all
          (fun (l, t) ->
            match List.assoc l r2 with
            | t' -> subtype t t'
            | exception Not_found -> false)
          r1
    | _ -> false

  exception Type_error of string * exp

  exception Row_error of string * row

  (* [type_exp env] returns a function computing the type of an expression
     under [env].  Binders ([Abs], [Fix]) re-enter [type_exp] with an
     extended environment; everything else stays in [type_exp0], which
     closes over [env]. *)
  let rec type_exp (env : typ Env.env) : exp -> typ =
    let rec type_exp0 (e : exp) : typ =
      match e with
      | Var v -> Env.lookup env v
      | Abs (v, t, body) -> ArrowT (t, type_exp (Env.extend env v t) body)
      | App (e1, e2) -> (
          match type_exp0 e1 with
          | ArrowT (t_formal, t_result) ->
              if subtype (type_exp0 e2) t_formal then t_result
              else raise (Type_error ("actual doesn't match formal", e))
          | _ -> raise (Type_error ("operator not a function", e)))
      | Int _ -> IntT
      | Fix (fs, e2) ->
          (* First bind every recursive name to its declared arrow type, so
             the bodies (and [e2]) can refer to any of them. *)
          let bind acc (f, e0, rt) =
            match e0 with
            | Abs (_, vt, _) -> Env.extend acc f (ArrowT (vt, rt))
            | _ ->
                raise
                  (Type_error
                     ("recursively-defined identifier must be abstraction", e))
          in
          let env1 = List.fold_left bind env fs in
          (* Each body must have a subtype of its annotated result type.
             Using [subtype] rather than structural equality accepts more
             specific bodies and is insensitive to the order in which row
             labels are written, consistently with [App]. *)
          let check_def (_, e0, rt) =
            match e0 with
            | Abs (v, vt, body) ->
                if not (subtype (type_exp (Env.extend env1 v vt) body) rt) then
                  raise (Type_error ("return type annotation doesn't match", e))
            | _ ->
                (* unreachable: [bind] has already rejected non-abstractions *)
                raise
                  (Type_error
                     ("recursively-defined identifier must be abstraction", e))
          in
          List.iter check_def fs;
          type_exp env1 e2
      | Record r -> RecordT (type_row0 r)
      | Select (l, e1) -> (
          match type_exp0 e1 with
          | RecordT rtyp -> (
              match List.assoc l rtyp with
              | t -> t
              | exception Not_found -> raise (Type_error ("bad label", e)))
          | _ -> raise (Type_error ("select from non-record", e)))
      | Variant (l, e1) -> SumT [ (l, type_exp0 e1) ]
      | Switch (_, []) -> raise (Type_error ("switch with no arms", e))
      | Switch (e1, r) -> (
          let t1 = type_exp0 e1 in
          match t1 with
          | SumT _ ->
              let rt = type_row0 r in
              (* Every arm must be a function; split each arm's type into
                 (label, argument type) and result type. *)
              let split_arm = function
                | lab, ArrowT (arg, res) -> ((lab, arg), res)
                | _ -> raise (Type_error ("switch arm not lambda", e))
              in
              let arg_rt, res_typs = List.split (List.map split_arm rt) in
              (* The scrutinee's sum must be handled by the arms' labels and
                 argument types. *)
              if not (subtype t1 (SumT arg_rt)) then
                raise (Type_error ("switch arm label missing", e))
              else (
                try join_type_list res_typs
                with Incomparable ->
                  raise
                    (Type_error ("switch arm result types inconsistent", e)))
          | _ -> raise (Type_error ("switch on non-sum", e)))
    (* Types every entry of a row and returns the row type sorted by label,
       as the join/meet helpers require.  Raises [Row_error] on a repeated
       label. *)
    and type_row0 (r : row) : rtyp =
      let sorted = List.stable_sort (fun (l1, _) (l2, _) -> compare l1 l2) r in
      (* After sorting, duplicate labels are adjacent. *)
      let rec check_nodups = function
        | (l1, _) :: ((l2, _) :: _ as rest) ->
            if l1 = l2 then raise (Row_error ("duplicate label in row", r))
            else check_nodups rest
        | [] | [ _ ] -> ()
      in
      check_nodups sorted;
      List.map (fun (l, e) -> (l, type_exp0 e)) sorted
    in
    type_exp0

  (* Initial environment giving the types of the primitive operations.
     [List.fold_left2] pairs [primops] with the three types below in order,
     and raises [Invalid_argument] if [primops] does not have exactly three
     elements. *)
  let initial_env =
    List.fold_left2 Env.extend Env.empty primops
      [ ArrowT (tPairT, IntT); ArrowT (tPairT, IntT); ArrowT (tPairT, tBoolT) ]

  let check_in_env = type_exp

  let check e = check_in_env initial_env e
end