{- Interpreter for a simple call-by-need language, 
   using a heap to express sharing and support 
   recursive bindings, and an environment to avoid 
   substitution. -}

import Map
import Maybe
import Monad
import Forest
import Trace
import List

-- Names occurring in source programs.
type Var = String     -- variable names
type Constr = String  -- data-constructor names
type Prim = String    -- primitive-operator names

-- | A case pattern: a constructor name and the variables bound to its
-- fields, in order.
type Pattern = (Constr,[Var])

-- | Expressions of the source language.
data Exp
  = Var Var                   -- ^ variable reference
  | Int Int                   -- ^ integer literal
  | Abs Var Exp               -- ^ lambda abstraction (one parameter)
  | App Exp Exp               -- ^ function application
  | Capp Constr [Exp]         -- ^ constructor applied to field expressions
  | Primapp Prim Exp Exp      -- ^ binary primitive operation
  | Case Exp [(Pattern,Exp)]  -- ^ case analysis on a constructor
  | Letrec [(Var,Exp)] Exp    -- ^ mutually recursive local bindings
  deriving (Show,Eq)


-- | Heap pointers: addresses of heap cells.
type HPtr = Int

-- | Results of evaluation. Constructor fields are held as heap pointers,
-- so they may still be unevaluated thunks.
data Value
  = VCapp Constr [HPtr]  -- ^ constructor with pointers to its fields
  | VInt Int             -- ^ integer
  | VAbs Env Var Exp     -- ^ closure: captured environment, parameter, body
  deriving (Show,Eq)

-- | Contents of a heap cell.
data HEntry
  = HValue Value         -- ^ an evaluated value
  | HThunk Env Exp       -- ^ a suspended expression with its environment
  | HBlackhole           -- ^ a thunk that is currently being evaluated
  deriving Show

-- | Environments map variables to heap locations rather than to values,
-- so that every use of a variable goes through the same (shared) cell.
type Env = Map Var HPtr

-- | The heap: a supply of not-yet-used locations plus the entries bound to
-- the locations allocated so far.
data Heap = Heap {free :: [HPtr], bindings :: Map HPtr HEntry}

-- | Only the bindings are shown: the free-location supply created by
-- 'hempty' is infinite, so printing it would never finish.
instance Show Heap where
  show = show . bindings

-- | Take the next unused location from the free supply.
-- An exhausted supply is reported explicitly rather than as a
-- non-exhaustive pattern (it cannot happen for heaps built from 'hempty').
hfresh :: Heap -> (HPtr,Heap)
hfresh h =
  case free h of
    v : vs -> (v, h {free = vs})
    []     -> error "hfresh: free-location supply exhausted"

-- | A heap with no bindings whose free locations are 0, 1, 2, ...
hempty :: Heap
hempty = Heap {free = [0..], bindings = mempty}

-- | The entry stored at a location.
-- NOTE(review): the result for an unallocated location is whatever
-- 'mget' does on a missing key — confirm in the Map module.
hget :: Heap -> HPtr -> HEntry
hget h v = mget (bindings h) v

-- | Bind a location to an entry (via 'mset'); the free supply is untouched.
hset :: Heap -> (HPtr,HEntry) -> Heap
hset h (v,e) = h {bindings = mset (bindings h) (v,e)}

-- | Answers: 'Just' a result, or 'Nothing' when evaluation hits an error.
-- Maybe's monad instance supplies (>>=), return, mzero and mplus for A.
type A a = Maybe a

-- | The results carried by an answer, as a list (empty for an error).
sols :: A a -> [a]
sols = maybe [] (: [])

-- | Interpreter computations: a function from the current heap to an
-- answer holding the result together with the updated heap.
newtype M a = M (Heap -> A (a,Heap))

instance Monad M where
  return x = M (\h -> return (x,h))
  M step >>= k = M (\h -> step h >>= continue)
    where
      -- Run the computation chosen by k on the heap that 'step' left behind.
      continue (x,h') = case k x of M step' -> step' h'

-- | Failure and choice come from the answer type. Both alternatives of
-- 'mplus' start from the same heap.
instance MonadPlus M where
  mzero = M (const mzero)
  M left `mplus` M right = M (\h -> left h `mplus` right h)
-- | Allocate an unused heap location (nothing is bound to it yet).
fresh :: M HPtr
fresh = M (return . hfresh)

-- | Bind a heap location to an entry.
-- NOTE(review): replacing an earlier entry relies on 'mset' overwriting.
store :: HPtr -> HEntry -> M ()
store p e = M (\heap -> return ((), hset heap (p,e)))

-- | Read the entry at a heap location without changing the heap.
fetch :: HPtr -> M HEntry
fetch p = M (\heap -> return (hget heap p, heap))

-- | Run a computation starting from the empty heap.
run :: M a -> A (a,Heap)
run (M go) = go hempty

-- | Evaluate an expression, call-by-need, to a 'Value'.
--
-- Constructor fields and function arguments are not evaluated eagerly.
-- Each is allocated in the heap as a thunk closing over the current
-- environment. Evaluating a variable forces its thunk and then overwrites
-- the cell with the resulting value, so all variables bound to that cell
-- share the work. While a thunk is being forced its cell holds a black
-- hole, so a value that depends directly on itself (e.g.
-- @letrec x = x in x@) produces an error answer instead of looping.
--
-- Run-time errors yield 'mzero' (the error answer) rather than crashing:
--
--   * applying a value that is not a closure ('VAbs');
--   * scrutinising a value that is not a constructor ('VCapp');
--   * a 'Case' with no alternative for the scrutinee's constructor;
--   * a matching alternative whose pattern binds a different number of
--     variables than the constructor has fields;
--   * re-entering a black hole.
--
-- Failures inside 'doPrimapp' or 'mget' are not turned into 'mzero' here.
eval :: Env -> Exp -> M Value
eval _ (Int i) = return (VInt i)
eval env (Abs x b) = return (VAbs env x b)
eval env (Capp c es) =
  do ps <- mapM (const fresh) es
     zipWithM_ store ps (map (HThunk env) es)
     return (VCapp c ps)
eval env (App e0 e1) =
  do f <- eval env e0
     -- Explicit case instead of a failable do-pattern: a non-function is
     -- an error answer, not a call to the monad's default 'fail'.
     case f of
       VAbs env' x b ->
         do p1 <- fresh
            store p1 (HThunk env e1)  -- argument is shared, not evaluated
            eval (mset env' (x,p1)) b
       _ -> mzero
eval env (Letrec xes e) =
  do let (xs,es) = unzip xes
     ps <- mapM (const fresh) xes
     -- Every right-hand side closes over env', which already binds all
     -- the new names: this is what makes the bindings recursive.
     let env' = foldl mset env (zip xs ps)
     zipWithM_ store ps (map (HThunk env') es)
     eval env' e
eval env (Var x) =
  do let p = mget env x
     entry <- fetch p
     case entry of
       HValue v -> return v
       HThunk env' e' ->
         do store p HBlackhole        -- mark the cell as under evaluation
            v' <- eval env' e'
            store p (HValue v')       -- update so later uses share the value
            return v'
       HBlackhole -> mzero            -- the value depends on itself
eval env (Case e pes) =
  do scrut <- eval env e
     case scrut of
       VCapp c0 ps ->
         -- First alternative whose constructor matches the scrutinee.
         case [ (xs,b) | ((c,xs),b) <- pes, c == c0 ] of
           (xs,b) : _ ->
             do guard (length xs == length ps)  -- arity must agree
                eval (foldl mset env (zip xs ps)) b
           [] -> mzero                          -- no alternative for c0
       _ -> mzero                               -- not a constructor value
eval env (Primapp p e1 e2) =
  do v1 <- eval env e1
     v2 <- eval env e2
     return (doPrimapp p v1 v2)

-- | Apply a primitive operator to two evaluated operands.
--
-- Supported primitives, all on 'VInt' operands:
--
--   * "eq"  - equality, giving the nullary constructor "True" or "False";
--   * "add", "sub", "mul" - integer arithmetic.
--
-- Any other primitive name, or a non-integer operand, is reported with a
-- descriptive error instead of a bare non-exhaustive-pattern failure.
doPrimapp :: Prim -> Value -> Value -> Value
doPrimapp "eq"  (VInt i1) (VInt i2) = VCapp (if i1 == i2 then "True" else "False") []
doPrimapp "add" (VInt i) (VInt j) = VInt (i+j)
doPrimapp "sub" (VInt i) (VInt j) = VInt (i-j)
doPrimapp "mul" (VInt i) (VInt j) = VInt (i*j)
doPrimapp p v1 v2 =
  error ("doPrimapp: cannot apply primitive " ++ show p
         ++ " to " ++ show v1 ++ " and " ++ show v2)

-- | Evaluate a closed expression from an empty environment and heap.
interp :: Exp -> A (Value,Heap)
interp = run . eval mempty

-- | 'interp' with its answer presented as a list of results.
interp' :: Exp -> [(Value,Heap)]
interp' e = sols (interp e)

-- | The distinct result values of 'interp'', with the final heaps dropped.
interp'' :: Exp -> [Value]
interp'' e = nub [ v | (v, _) <- interp' e ]

{- Examples -}

-- | Booleans, encoded as the nullary constructors "True" and "False".
true, false :: Exp
true = Capp "True" []
false = Capp "False" []

-- | Conditional, encoded as a case on the "True"/"False" constructors.
ifthenelse :: Exp -> Exp -> Exp -> Exp
ifthenelse c t f = Case c alts
  where alts = [(("True",[]), t), (("False",[]), f)]

-- | Test whether an integer expression equals zero.
eq0 :: Exp -> Exp
eq0 e = Primapp "eq" e zero
  where zero = Int 0


-- factorial on ints, as the program
--   letrec f = \x -> if x == 0 then 1 else f (x - 1) * x in f n
-- Assumes n >= 0: for negative n the recursion never reaches 0 in practice.
fact :: Int -> Exp
fact n = Letrec [("f", Abs "x" body)] (App (Var "f") (Int n))
  where
    x       = Var "x"
    recCall = App (Var "f") (Primapp "sub" x (Int 1))
    body    = ifthenelse (eq0 x) (Int 1) (Primapp "mul" recCall x)


-- take on lists, as the program
--   letrec take = \n -> \l ->
--                   if n == 0 then Nil
--                   else case l of Cons x y -> Cons x (take (n - 1) y)
--                                  Nil      -> Nil
--   in take
-- The Nil alternative returns a list shorter than n unchanged. Without it,
-- forcing the result of taking more elements than the list has would fail
-- with "no matching alternative" (the error answer).
mtake :: Exp
mtake = Letrec [("take", Abs "n" (Abs "l" body))] (Var "take")
  where
    body = ifthenelse (eq0 (Var "n"))
             (Capp "Nil" [])
             (Case (Var "l") [consAlt, nilAlt])
    consAlt = (("Cons",["x","y"]),
               Capp "Cons" [Var "x",
                            App (App (Var "take")
                                     (Primapp "sub" (Var "n") (Int 1)))
                                (Var "y")])
    nilAlt = (("Nil",[]), Capp "Nil" [])


-- len on lists, as the program
--   letrec len = \l -> case l of Cons _ y -> len y + 1
--                                Nil      -> 0
--   in len
-- "_" is an ordinary variable name here; the interpreter binds it like any other.
mlen :: Exp
mlen = Letrec [("len", lenFun)] (Var "len")
  where
    lenFun  = Abs "l" (Case (Var "l") [consAlt, nilAlt])
    consAlt = (("Cons",["_","y"]),
               Primapp "add" (App (Var "len") (Var "y")) (Int 1))
    nilAlt  = (("Nil",[]), Int 0)

-- head on lists; an empty list yields the sentinel Int (-1).
mhead :: Exp
mhead = Abs "l" (Case (Var "l") alts)
  where
    alts = [ (("Cons",["x","_"]), Var "x")
           , (("Nil",[]), Int (-1))
           ]

-- tail on lists; an empty list yields the sentinel Int (-1), not a list.
mtail :: Exp
mtail = Abs "l" (Case (Var "l") alts)
  where
    alts = [ (("Cons",["_","x"]), Var "x")
           , (("Nil",[]), Int (-1))
           ]

-- some lists

-- | take 2 applied to the list [0,1,2,3].
a :: Exp
a = App (App mtake (Int 2)) (foldr cons nil [0,1,2,3])
  where
    cons x rest = Capp "Cons" [Int x, rest]
    nil         = Capp "Nil" []

-- | take 10 applied to an infinite list of ones, built by a recursive binding.
b :: Exp
b = Letrec [("ones", ones)] (App (App mtake (Int 10)) (Var "ones"))
  where ones = Capp "Cons" [Int 1, Var "ones"]


-- Some observables, each followed by its expected result.
fact10, alen, blen :: [Value]
fact10 = interp'' $ fact 10       -- [VInt 3628800]
alen   = interp'' $ App mlen a    -- [VInt 2]
blen   = interp'' $ App mlen b    -- [VInt 10]



