--

import ST

---------------------------------------------------------
-- Inference Monad

-- The inference monad.  A computation takes the current Int counter and
-- runs in 'ST s', producing a result, the text emitted so far, and the
-- updated counter.
newtype IM s x = Ck (Int -> ST s (x, String, Int))

-- Apply a pure function to the result of a computation; the emitted text
-- and the counter pass through unchanged.
instance Functor (IM a) where
  fmap f (Ck run) =
    Ck (\n -> run n >>= \(x, out, n') -> return (f x, out, n'))
                    
                   
-- Sequencing threads the counter from the first computation into the
-- second and concatenates the text emitted by both, in order.
instance Monad (IM a) where
  -- Yield a value, emit nothing, leave the counter alone.
  return x = Ck (\n -> return (x, "", n))
  Ck first >>= k = Ck chained
    where
      chained n0 = do
        (a, out1, n1) <- first n0
        let Ck second = k a
        (y, out2, n2) <- second n1
        return (y, out1 ++ out2, n2)
                      
-- Primitive operations on mutable references, lifted into IM.  Each one
-- performs the underlying ST action, emits no text and leaves the counter
-- unchanged.

-- Read the current contents of a reference.
readVar :: STRef a b -> IM a b
readVar ref = Ck (\n -> readSTRef ref >>= \z -> return (z, "", n))

-- Allocate a new reference holding the given initial value.
newVar :: a -> IM c (STRef c a)
newVar init = Ck (\n -> newSTRef init >>= \z -> return (z, "", n))

-- Overwrite the contents of a reference.
writeVar :: STRef a b -> b -> IM a ()
writeVar ref value =
  Ck (\n -> writeSTRef ref value >>= \z -> return (z, "", n))

-- Return the current counter value and advance the counter by one.
nextN :: IM a Int
nextN = Ck (\n -> return (n, "", n + 1))
    

-- Emit a string into the computation's output text; the counter is
-- unchanged and the result is ().
printS :: String -> IM a ()
printS s = Ck (\n -> return ((), s, n))

-- Emit a label followed by the shown value and the terminator " -\n".
pr :: Show a => [Char] -> a -> IM b ()
pr s x = printS (concat [s, show x, " -\n"])


-- Run a computation starting from the given counter value, returning its
-- result, the text it emitted and the final counter.  The argument must be
-- polymorphic in the state thread so that 'runST' can be applied.
runIM :: (forall a . IM a c) -> Int -> (c,String,Int)
runIM w n = runST (case w of Ck f -> f n)
  
  
-- Run a computation with the counter starting at 0 and keep only its
-- result, discarding the emitted text and the final counter.
force :: (forall a . IM a c) -> c
force w = result
  where
    (result, _, _) = runIM w 0


---------------------------------------------------------------  
-- Representing Types     

-- Types manipulated by the inferencer.  The parameter is the ST state
-- thread that owns the mutable type variables.
data Type s
  = Tunit                              -- the unit type
  | Tarrow (Type s) (Type s)           -- function type: argument, result
  | Ttuple [Type s]                    -- tuple of component types
  | Tdata String [Type s]              -- named type applied to arguments
  | Tgen Int                           -- generic variable, identified by index
  | Tvar (STRef s (Maybe (Type s)))    -- unification variable: Nothing while
                                       -- unbound, Just t once bound to t

-- A type scheme: the indices of the generic variables ('Tgen') that are
-- quantified, together with the type body.
data Scheme s = Sch [Int] (Type s)

-- Handlers for the different ways unification can fail.  Each handler
-- receives the two (pruned) types that could not be unified.
class Error a b where
  -- a variable occurs inside the type it is being bound to
  occursCk     :: Type a -> Type a -> IM a b
  -- two names (generic indices or type names) differ
  nameMtch     :: Type a -> Type a -> IM a b
  -- the two types have different outermost shapes
  shapeMtch    :: Type a -> Type a -> IM a b
  -- two tuple types have different numbers of components
  tupleLenMtch :: Type a -> Type a -> IM a b


-- Unify two types, binding unification variables as needed.
--
-- Both arguments are pruned first.  The result is the concatenation of
-- whatever the 'Error' handlers return for each mismatch found; a
-- successful unification yields [].
--
-- Fix over the previous version: (Tunit,Tunit) and (Tdata,Tdata) are now
-- handled explicitly.  Before, both fell through to the catch-all case, so
-- unifying e.g. Tdata "Int" [] with itself was reported as a shape
-- mismatch.
unify :: Error a [String] => Type a -> Type a -> IM a [String]
unify tA tB = do
  t1 <- prune1 tA
  t2 <- prune1 tB
  case (t1, t2) of
    -- Both are (unbound) variables: identical ones need no work.
    (Tvar r1, Tvar r2)
      | r1 == r2 -> return []
      | otherwise -> bind r1 t2
    -- One is a variable: bind it unless that would build a cyclic type.
    (Tvar r1, _) -> do
      b <- occursIn1 r1 t2
      if b then occursCk t1 t2 else bind r1 t2
    -- Variable on the right: reuse the case above with arguments swapped.
    (_, Tvar _) -> unify t2 t1
    (Tunit, Tunit) -> return []
    (Tgen s, Tgen t)
      | s == t -> return []
      | otherwise -> nameMtch t1 t2
    (Tarrow x y, Tarrow m n) -> do
      cs1 <- unify x m
      cs2 <- unify y n
      return (cs1 ++ cs2)
    (Ttuple xs, Ttuple ys)
      | length xs == length ys -> unifyAll xs ys
      | otherwise -> tupleLenMtch t1 t2
    -- Same type name and arity: unify the arguments pairwise.
    -- NOTE(review): an arity mismatch under the same name is reported via
    -- shapeMtch since no dedicated handler exists -- confirm this choice.
    (Tdata c xs, Tdata d ys)
      | c /= d -> nameMtch t1 t2
      | length xs /= length ys -> shapeMtch t1 t2
      | otherwise -> unifyAll xs ys
    (_, _) -> shapeMtch t1 t2
  where
    -- Bind a variable to a type; binding itself never fails.
    bind r t = do
      writeVar r (Just t)
      return []
    -- Unify two equally long lists pointwise, collecting all results.
    unifyAll xs ys = do
      css <- mapM (uncurry unify) (zip xs ys)
      return (concat css)


-- Follow a chain of bound type variables to its end.  Every bound variable
-- visited is overwritten to point directly at the final type, so later
-- lookups are shorter.  Unbound variables and non-variable types are
-- returned as they are.
prune1 :: Type a -> IM a (Type a)
prune1 typ =
  case typ of
    Tvar ref -> do
      contents <- readVar ref
      case contents of
        Nothing -> return typ
        Just t -> do
          newt <- prune1 t
          writeVar ref (Just newt)
          return newt
    _ -> return typ


-- Does the variable reference r occur anywhere inside type t?  The type is
-- pruned at every level, and all sub-terms are always visited (there is no
-- early exit once an occurrence has been found).
occursIn1 :: Tref a -> Type a -> IM a Bool
occursIn1 r t = do
  pruned <- prune1 t
  case pruned of
    Tvar z -> return (r == z)
    Tunit -> return False
    Tgen _ -> return False
    Tarrow x y -> do
      inArg <- occursIn1 r x
      inRes <- occursIn1 r y
      return (inArg || inRes)
    Ttuple xs -> inAny xs
    Tdata _ xs -> inAny xs
  where
    -- True when r occurs in at least one of the given types.
    inAny ts = do
      bs <- mapM (occursIn1 r) ts
      return (or bs)

---------------------------------------------------------------
-- Generalize

-- The mutable cell behind a 'Tvar': Nothing while unbound, Just t once
-- bound to t.
type Tref s = STRef s (Maybe (Type s))

-- Replace by generic variables ('Tgen') every unbound variable of t that
-- satisfies the predicate.  The association list records which index each
-- variable has been given so far; it is threaded through the traversal and
-- returned, possibly extended, alongside the rewritten type.
gen :: (Tref a -> IM a Bool) -> Type a -> [(Tref a,Int)] -> IM a (Type a,[(Tref a,Int)])
gen pred t pairs = do
  t1 <- prune1 t
  let unchanged = return (t1, pairs)
  case t1 of
    Tvar r -> do
      b <- pred r
      if b then genVar r pairs else unchanged
    Tgen _ -> unchanged
    Tunit -> unchanged
    Tarrow x y -> do
      (x', p1) <- gen pred x pairs
      (y', p2) <- gen pred y p1
      return (Tarrow x' y', p2)
    Ttuple ts -> do
      (ts', p) <- thread pred ts pairs
      return (Ttuple ts', p)
    Tdata c ts -> do
      (ts', p) <- thread pred ts pairs
      return (Tdata c ts', p)

-- Apply 'gen' to each type in a list from left to right, passing the
-- association list produced by one element on to the next.
thread :: (Tref a -> IM a Bool) -> [Type a] -> [(Tref a,Int)] -> IM a ([Type a],[(Tref a,Int)])
thread p ts pairs =
  case ts of
    [] -> return ([], pairs)
    (t : rest) -> do
      (t', p1) <- gen p t pairs
      (rest', p2) <- thread p rest p1
      return (t' : rest', p2)

-- Find the generic index assigned to a variable.  If the variable is
-- already in the association list its existing index is reused; otherwise
-- a new index is drawn from the counter and the pair is added at the end
-- of the list.
genVar :: Tref a -> [(Tref a,Int)] -> IM a (Type a,[(Tref a,Int)])
genVar r pairs =
  case pairs of
    [] -> do
      n <- nextN
      return (Tgen n, [(r, n)])
    (entry@(r1, n) : more)
      | r1 == r -> return (Tgen n, pairs)
      | otherwise -> do
          (t, rest) <- genVar r more
          return (t, entry : rest)


-- Build a type scheme from a type by generalizing every unbound variable
-- accepted by the predicate.  The scheme quantifies exactly the indices
-- handed out during the traversal.
generalize :: (Tref a -> IM a Bool) -> Type a -> IM a (Scheme a)
generalize p t = fmap close (gen p t [])
  where
    close (t', pairs) = Sch (map snd pairs) t'

-- Create a new, unbound unification variable.
freshVar :: IM a (Type a)
freshVar = fmap Tvar (newVar Nothing)

-- Instantiate a scheme: allocate one fresh variable per quantified index
-- and substitute those variables for the generic variables in the body.
instantiate :: Scheme a -> IM a (Type a)
instantiate (Sch ns t) = do
  ts <- mapM (const freshVar) ns
  subGen (zip ns ts) t

-- Rebuild a type, replacing each generic variable by the type associated
-- with its index in the substitution.  Unification variables are left in
-- place (after pruning).  A generic index missing from the substitution
-- makes 'find' raise an error.
subGen :: [(Int, Type a)] -> Type a -> IM a (Type a)
subGen sub t = do
  t2 <- prune1 t
  case t2 of
    Tunit -> return Tunit
    Tgen s -> return (find s sub)
    Tvar z -> return (Tvar z)
    Tarrow x y -> do
      x' <- subGen sub x
      y' <- subGen sub y
      return (Tarrow x' y')
    Ttuple xs -> fmap Ttuple (mapM (subGen sub) xs)
    Tdata name xs -> fmap (Tdata name) (mapM (subGen sub) xs)

-- Look up a key in an association list, returning the value of the first
-- matching entry.  Raises an error when the key is absent.
find :: Eq a => a -> [(a, b)] -> b
find s pairs =
  case lookup s pairs of
    Just b -> b
    Nothing -> error "unknown generic var"
----------------------------------------------------------------

-- Expressions of the object language.
data Exp
  = App Exp Exp                           -- application: function, argument
  | Abs String Exp                        -- lambda binding a single variable
  | Var String                            -- variable reference
  | Tuple [Exp]                           -- tuple construction
  | Const Int                             -- integer literal
  | Let String Exp Exp                    -- let: binder, definition, body
  | MutRecLet [(String,Exp)] Exp          -- group of named definitions, body
  | PatAbs Pat Exp                        -- lambda binding a pattern
  | MutPatRecLet (Pat,Exp) (Pat,Exp) Exp  -- two pattern definitions, body
 
-- Patterns usable on the binding side of PatAbs and MutPatRecLet.
data Pat
  = PVar String    -- binds a variable
  | PTuple [Pat]   -- matches a tuple component-wise
  | PWildcard      -- matches anything, binds nothing

-- Turn a (name, type) binding into a (name, scheme) binding that
-- quantifies over no variables.
toScheme :: (n, Type a) -> (n, Scheme a)
toScheme (name, typ) = (name, Sch [] typ)

-- Infer the type of a pattern.  Every variable and wildcard receives a
-- fresh type variable.  The first argument is the list of bindings
-- collected so far; each pattern variable is added to the front of it and
-- the extended list is returned together with the pattern's type.
inferPat :: [(String,Type a)] -> Pat -> IM a (Type a,[(String,Type a)])
inferPat delta pat =
  case pat of
    PWildcard -> do
      typ <- freshVar
      return (typ, delta)
    PVar s -> do
      typ <- freshVar
      return (typ, (s, typ) : delta)
    PTuple ps -> do
      (ts, delta') <- components ps delta
      return (Ttuple ts, delta')
  where
    -- Process the sub-patterns left to right, threading the bindings.
    components [] acc = return ([], acc)
    components (q : qs) acc = do
      (t, acc1) <- inferPat acc q
      (ts, acc2) <- components qs acc1
      return (t : ts, acc2)

-- Infer the type of an expression under an environment mapping variable
-- names to type schemes.
--
-- Fixes over the previous version:
--   * MutRecLet: the fresh binder types are now kept in declaration order.
--     Previously the list was built in reverse and then paired
--     positionally with the definitions, so each binder was unified with
--     the type of a different definition.
--   * MutRecLet / MutPatRecLet: the body is inferred under the generalized
--     bindings placed in front of the enclosing environment (as Let
--     already did), instead of under the new bindings alone.
--   * Var: an unbound variable is reported as such, naming the variable.
--
-- NOTE(review): the values returned by 'unify' are discarded throughout;
-- presumably the Error instance reports problems by other means -- confirm.
infer :: Error a [String] => Exp -> [(String,Scheme a)] -> IM a (Type a)
infer e env =
  case e of
    Var s ->
      case lookup s env of
        Just schm -> instantiate schm
        Nothing -> error ("unbound variable: " ++ s)
    App f x -> do
      ftyp <- infer f env
      xtyp <- infer x env
      result <- freshVar
      unify (Tarrow xtyp result) ftyp
      return result
    Abs x body -> do
      xtyp <- freshVar
      etyp <- infer body ((x, Sch [] xtyp) : env)
      return (Tarrow xtyp etyp)
    PatAbs pat body -> do
      (ptyp, delta) <- inferPat [] pat
      etyp <- infer body (map toScheme delta ++ env)
      return (Tarrow ptyp etyp)
    Tuple es -> do
      ts <- mapM (\el -> infer el env) es
      return (Ttuple ts)
    Const _ -> return (Tdata "Int" [])
    Let x rhs b -> do
      -- The binder is in scope, monomorphically, in its own definition.
      xtyp <- freshVar
      etyp <- infer rhs ((x, Sch [] xtyp) : env)
      unify xtyp etyp
      schm <- generalize (generic env) etyp
      infer b ((x, schm) : env)
    MutPatRecLet (p1, e1) (p2, e2) b -> do
      (t1, delta1) <- inferPat [] p1
      (t2, delta2) <- inferPat delta1 p2
      -- Both definitions see every pattern variable monomorphically.
      let env2 = map toScheme delta2 ++ env
      etyp1 <- infer e1 env2
      unify etyp1 t1
      etyp2 <- infer e2 env2
      unify etyp2 t2
      env3 <- polyEnv env delta2 delta2
      infer b (env3 ++ env)
    MutRecLet ds b -> do
      -- One fresh monomorphic type per binding, in declaration order.
      xtyps <- mapM (const freshVar) ds
      let envDelta = zip (map fst ds) xtyps
          env2 = map toScheme envDelta ++ env
      mapM_ (checkBinding env2) (zip xtyps (map snd ds))
      env3 <- mapM generalizeBinding envDelta
      infer b (env3 ++ env)
  where
    -- Infer one definition and unify it with its binder's type.
    checkBinding env2 (xtyp, rhs) = do
      etyp <- infer rhs env2
      unify xtyp etyp
    -- Generalize a binder's type over variables not occurring in env.
    generalizeBinding (x, xtyp) = do
      schm <- generalize (generic env) xtyp
      return (x, schm)
      
      

-- Generalize each binding in the third argument, in order.  A variable is
-- generalized in the type of binding x only if it occurs neither in the
-- environment nor in the type of any differently named binding in delta.
polyEnv :: [(n, Scheme a)] -> [(String, Type a)] -> [(String, Type a)] -> IM a [(String, Scheme a)]
polyEnv env delta bindings = mapM generalizeOne bindings
  where
    generalizeOne (x, xtyp) = do
      schm <- generalize (generalizable x) xtyp
      return (x, schm)
    -- Both checks are always run, even when the first already fails.
    generalizable s r = do
      b1 <- generic env r
      b2 <- mutualgeneric s delta r
      return (b1 && b2)
                    
-- True when the variable r does not occur in the type of any binding whose
-- name differs from s.  Stops at the first binding in which r occurs.
mutualgeneric :: String -> [(String,Type a)] -> Tref a -> IM a  Bool
mutualgeneric s delta r =
  case delta of
    [] -> return True
    ((name, typ) : more)
      | name == s -> mutualgeneric s more r
      | otherwise -> do
          b <- occursIn1 r typ
          if b then return False else mutualgeneric s more r
     
-- True when the variable r occurs in the body of no scheme in the
-- environment.  Stops at the first scheme in which r occurs.
generic :: [(n,Scheme a)] -> Tref a -> IM a  Bool
generic env r =
  case env of
    [] -> return True
    ((_, Sch _ typ) : more) -> do
      b <- occursIn1 r typ
      if b then return False else generic more r
        
--