header {* State monad transformer *}
theory StateT
imports MonadClass
begin
subsection {* Type definition and monad operations *}
(* Concrete state-transformer type: a single-constructor wrapper around a
   function from an initial state to a value of type ('a × 's)•'m, i.e. the
   underlying type constructor 'm applied to (result, final state) pairs. *)
datatype ('a,'m,'s) stateT =
stateT "'s => ('a × 's)•'m::tycon"
(* Selector for the only constructor: unwraps a stateT value back to its
   underlying state function. *)
consts
run_stateT :: "('a,'m,'s) stateT => 's => ('a × 's)•'m::tycon"
primrec
"run_stateT (stateT x) = x"
(* Re-wrapping the unwrapped function yields the original value. Together
   with the primrec equation this makes stateT/run_stateT mutually inverse.
   Proved by structural induction (case analysis) on x. *)
lemma run_stateT_inverse [simp]: "stateT (run_stateT x) = x"
by (induct x, simp)
constdefs
  (* Map f over the result component of every (result, state) pair produced
     by the wrapped computation, using the underlying functor's fmap; the
     threaded state is passed through unchanged. *)
  fmap_stateT :: "('a => 'b) => ('a,'m::functor,'s)stateT => ('b,'m,'s)stateT"
  "fmap_stateT ≡ λg c. stateT (λs0.
     fmap (λ(a, s1). (g a, s1)) (run_stateT c s0))"
  (* Yield the given value and leave the state untouched. *)
  return_stateT :: "'a => ('a,'m::monad,'s)stateT"
  "return_stateT ≡ λa. stateT (λs0. return (a, s0))"
  (* Sequencing: run c from the initial state, then run the continuation on
     its result from the intermediate state, in the underlying monad. *)
  bind_stateT :: "('a,'m::monad,'s)stateT =>
    ('a => ('c,'m,'s)stateT) => ('c,'m,'s)stateT"
  "bind_stateT ≡ λc k. stateT (λs0.
     do {(a, s1) \<leftarrow> run_stateT c s0; run_stateT (k a) s1})"
constdefs
  (* Run an underlying computation without changing the state: each result
     of c is paired with the current state via the underlying fmap. *)
  lift_stateT :: "'a•'m => ('a,'m::functor,'s)stateT"
  "lift_stateT ≡ λc. stateT (λs0. fmap (λa. (a, s0)) c)"
  (* Return the current state as the result; the state is unchanged. *)
  get_stateT :: "('s,'m::monad,'s)stateT"
  "get_stateT ≡ stateT (λs0. return (s0, s0))"
  (* Discard the current state, install the given one, and return (). *)
  set_stateT :: "'s => (unit,'m::monad,'s)stateT"
  "set_stateT ≡ λnew. stateT (λold. return ((), new))"
(* Functor identity law for fmap_stateT. split_def is needed to reduce the
   tupled lambda inside fmap_stateT_def. *)
lemma fmap_id_stateT: "fmap_stateT id xs = xs"
by (simp add: fmap_stateT_def split_def)
(* Functor composition law for fmap_stateT, reduced to the underlying
   functor's fmap_fmap law. *)
lemma fmap_fmap_stateT:
"fmap_stateT f (fmap_stateT g xs) = fmap_stateT (λx. f (g x)) xs"
by (simp add: fmap_stateT_def fmap_fmap split_def)
(* fmap_stateT agrees with the map derived from bind_stateT/return_stateT.
   First unfold the three stateT operations, then appeal to the underlying
   monad's monad_fmap law. *)
lemma monad_fmap_stateT:
"fmap_stateT f xs = bind_stateT xs (λx. return_stateT (f x))"
apply (simp add: fmap_stateT_def bind_stateT_def return_stateT_def)
apply (simp add: monad_fmap split_def)
done
(* Monad left-unit law for the stateT operations. *)
lemma monad_left_unit_stateT: "bind_stateT (return_stateT x) f = f x"
by (simp add: bind_stateT_def return_stateT_def)
(* Monad associativity law for bind_stateT, reduced to the underlying
   monad's monad_bind_assoc law. *)
lemma monad_bind_assoc_stateT:
"bind_stateT (bind_stateT xs f) g =
bind_stateT xs (λx. bind_stateT (f x) g)"
by (simp add: bind_stateT_def monad_bind_assoc split_def)
subsection {* @{text rep} instance *}
(* Make stateT a member of rep_consts so that the overloaded emb/proj
   constants can be defined at this type; the rep instance follows. *)
instance stateT :: (rep, functor, "{rep,finrep}") rep_consts ..
defs (overloaded)
(* emb: first embed every result with fmap_stateT emb, then unwrap with
   run_stateT, then embed the resulting state function. *)
emb_stateT_def: "emb ≡ emb o run_stateT o fmap_stateT emb"
(* proj: the reverse pipeline -- project to a state function, re-wrap with
   stateT, then project every result with fmap_stateT proj. *)
proj_stateT_def: "proj ≡ fmap_stateT proj o stateT o proj"
instance stateT :: (rep, functor, "{rep,finrep}") rep
apply (intro_classes)
(* After unfolding emb/proj, simp closes the goal using the functor
   composition and identity laws for fmap_stateT. *)
apply (unfold emb_stateT_def proj_stateT_def)
apply (simp add: fmap_fmap_stateT fmap_id_stateT [unfolded id_def])
done
subsection {* @{text functor} instance *}
(* Tag type used as a type constructor: 'a•('m,'s) StateT is shown
   isomorphic to ('a,'m,'s) stateT further below. The single constructor
   carries no data. *)
datatype ('m,'s) StateT = StateT
(* Declare StateT as a tycon; its monotc is given by the next definition. *)
instance StateT :: (functor, "{rep,finrep}") tycon ..
(* The monotc of StateT is functor_tc applied to fmap_stateT, instantiated
   at result type U. *)
defs (overloaded)
monotc_StateT_def:
"monotc (t::('m::functor,'s::{rep,finrep}) StateT itself) ≡ functor_tc
(fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"
text {* Next we show that @{text StateT}, with @{text fmap_stateT}, satisfies @{term functor_locale}. *}
(* rep_fmap at StateT is obtained from fmap_stateT at result type U through
   rep_fmap_of. *)
defs (overloaded)
rep_fmap_StateT_def:
"rep_fmap::(U => U) =>
U•('m::functor,'s::{rep,finrep})StateT => U•('m,'s)StateT
≡ rep_fmap_of
(fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"
(* StateT, with fmap_stateT at result type U, satisfies functor_locale.
   After functor_locale.intro the premises are discharged in order by the
   monotc and rep_fmap definitions, then by the functor identity and
   composition laws. The rule order must match the locale's premise order. *)
lemma functor_locale_StateT:
"functor_locale TYPE(('m::functor,'s::{rep,finrep}) StateT)
(fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"
apply (rule functor_locale.intro)
apply (rule monotc_StateT_def)
apply (rule rep_fmap_StateT_def)
apply (rule fmap_id_stateT)
apply (rule fmap_fmap_stateT)
done
(* The functor class instance follows from the locale fact above via
   functor_locale.functor_class. *)
instance StateT :: (functor, "{rep,finrep}") functor
apply (rule functor_locale.functor_class)
apply (rule functor_locale_StateT)
done
subsection {* @{text StateT} is isomorphic to @{text stateT} *}
(* Coercions between the concrete stateT type and the type-constructor
   application 'a•('m,'s) StateT. Both are defined as coerce; the lemma
   StateT_iso shows they are mutually inverse. *)
constdefs
abs_StateT :: "('a,'m::functor,'s::{finrep,rep}) stateT => 'a•('m,'s) StateT"
"abs_StateT ≡ coerce"
rep_StateT :: "'a•('m::functor,'s::{finrep,rep}) StateT => ('a,'m,'s) stateT"
"rep_StateT ≡ coerce"
(* emb on stateT factors through fmap_stateT emb. Used as a premise of
   rep_of_App_functor and of the fmap/return/bind characterizations below. *)
lemma emb_stateT: "emb xs = emb (fmap_stateT emb xs)"
by (simp add: emb_stateT_def fmap_id_stateT)
(* Dual of emb_stateT: proj on stateT factors through fmap_stateT proj. *)
lemma proj_stateT: "proj y = fmap_stateT proj (proj y)"
by (simp add: proj_stateT_def fmap_id_stateT)
(* stateT at 'a and the application 'a•('m,'s) StateT have the same rep_of.
   The proof instantiates functor_locale.rep_of_App_functor with the locale
   fact, the emb/proj factorizations and the composition law, in that order. *)
lemma rep_of_stateT:
"rep_of TYPE(('a,'m::functor,'s::{finrep,rep}) stateT) =
rep_of TYPE('a•('m,'s) StateT)"
apply (rule functor_locale.rep_of_App_functor)
apply (rule functor_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)
apply (rule fmap_fmap_stateT)
done
(* abs_StateT and rep_StateT are mutually inverse. Both are coerce, and
   rep_of_stateT equates the representations, so simp closes both goals. *)
lemma StateT_iso [simp]:
"rep_StateT (abs_StateT x) = x"
"abs_StateT (rep_StateT y) = y"
by (simp_all add: rep_StateT_def abs_StateT_def rep_of_stateT)
subsection {* @{text monad} instance *}
(* rep_return and rep_bind at StateT are obtained from return_stateT and
   bind_stateT at result type U through rep_return_of and rep_bind_of. *)
defs (overloaded)
rep_return_StateT_def:
"rep_return::U => U•('m::monad,'s::{rep,finrep})StateT
≡ rep_return_of (return_stateT::U => (U,'m,'s)stateT)"
rep_bind_StateT_def:
"rep_bind::U•('m::monad,'s::{rep,finrep})StateT =>
(U => U•('m,'s)StateT) => U•('m,'s)StateT
≡ rep_bind_of (bind_stateT::
(U,'m,'s)stateT => (U => (U,'m,'s)stateT) => (U,'m,'s)stateT)"
(* StateT with the stateT operations satisfies monad_locale. The functor
   part comes from functor_locale_StateT. The monad axioms are discharged in
   order by the rep_return/rep_bind definitions and the three monad laws;
   the rule order must match the locale's premise order. *)
lemma monad_locale_StateT:
"monad_locale TYPE(('m::monad,'s::{rep,finrep})StateT)
fmap_stateT (return_stateT::U => (U,'m,'s)stateT) bind_stateT"
apply (rule monad_locale.intro)
apply (rule functor_locale_StateT)
apply (rule monad_locale_axioms.intro)
apply (rule rep_return_StateT_def)
apply (rule rep_bind_StateT_def)
apply (rule monad_fmap_stateT)
apply (rule monad_left_unit_stateT)
apply (rule monad_bind_assoc_stateT)
done
(* The monad class instance follows from the locale fact above via
   monad_locale.monad_class. *)
instance StateT :: (monad, "{rep,finrep}") monad
apply (rule monad_locale.monad_class)
apply (rule monad_locale_StateT)
done
subsection {* Lifted state operations and characterization of the class operations *}
constdefs
  (* Lift an underlying 'm-computation into StateT: the concrete
     lift_stateT transported through abs_StateT. *)
  lift_StateT :: "'a•'m => 'a•('m::functor,'s::{finrep,rep}) StateT"
  "lift_StateT ≡ λc. abs_StateT (lift_stateT c)"
  (* Read the state: get_stateT transported through abs_StateT. *)
  get_StateT :: "'s•('m::monad,'s::{finrep,rep}) StateT"
  "get_StateT ≡ abs_StateT get_stateT"
  (* Overwrite the state: set_stateT transported through abs_StateT. *)
  set_StateT :: "'s => unit•('m::monad,'s::{finrep,rep}) StateT"
  "set_StateT ≡ λnew. abs_StateT (set_stateT new)"
(* Characterization of the class fmap at StateT: it is fmap_stateT
   conjugated by the abs/rep coercions. Obtained from functor_locale_fmap_def
   after unfolding abs/rep to coerce. *)
lemma fmap_StateT_def:
"fmap ≡ λf xs. abs_StateT (fmap_stateT f (rep_StateT xs))"
apply (unfold abs_StateT_def rep_StateT_def)
apply (rule functor_locale_fmap_def)
apply (rule functor_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)
apply (rule fmap_fmap_stateT)+
done
(* Characterization of the class return at StateT: return_stateT followed
   by abs_StateT. Obtained from monad_locale_return_def. *)
lemma return_StateT_def: "return ≡ λx. abs_StateT (return_stateT x)"
apply (unfold abs_StateT_def)
apply (rule monad_locale_return_def)
apply (rule monad_locale_StateT)
apply (rule emb_stateT)
apply (rule monad_fmap_stateT)
apply (rule monad_left_unit_stateT)
done
(* Characterization of the class bind at StateT: bind_stateT on the
   rep_StateT images of the computation and of each continuation result,
   wrapped with abs_StateT. Obtained from monad_locale_bind_def; the
   repeated rules discharge the several premise instances it generates. *)
lemma bind_StateT_def:
"bind ≡ λm k. abs_StateT
(bind_stateT (rep_StateT m) (λx. rep_StateT (k x)))"
apply (unfold abs_StateT_def rep_StateT_def)
apply (rule monad_locale_bind_def)
apply (rule monad_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)+
apply (rule monad_fmap_stateT)+
apply (rule monad_left_unit_stateT)+
apply (rule monad_bind_assoc_stateT)+
done
(* Sanity check: after set_StateT x, get_StateT yields x. The proof first
   rewrites the StateT operations to their stateT counterparts, then unfolds
   those concrete definitions. *)
lemma "do {u \<leftarrow> set_StateT x; get_StateT} = do {u \<leftarrow> set_StateT x; return x}"
apply (simp add:
bind_StateT_def return_StateT_def set_StateT_def get_StateT_def)
apply (simp add:
bind_stateT_def return_stateT_def set_stateT_def get_stateT_def)
done
end
lemma run_stateT_inverse:
stateT (run_stateT x) = x
lemma fmap_id_stateT:
fmap_stateT id xs = xs
lemma fmap_fmap_stateT:
fmap_stateT f (fmap_stateT g xs) = fmap_stateT (λx. f (g x)) xs
lemma monad_fmap_stateT:
fmap_stateT f xs = bind_stateT xs (λx. return_stateT (f x))
lemma monad_left_unit_stateT:
bind_stateT (return_stateT x) f = f x
lemma monad_bind_assoc_stateT:
bind_stateT (bind_stateT xs f) g = bind_stateT xs (λx. bind_stateT (f x) g)
lemma functor_locale_StateT:
functor_locale TYPE(('m, 's) StateT) fmap_stateT
lemma emb_stateT:
emb xs = emb (fmap_stateT emb xs)
lemma proj_stateT:
proj y = fmap_stateT proj (proj y)
lemma rep_of_stateT:
rep_of TYPE(('a, 'm, 's) stateT) = rep_of TYPE('a $ ('m, 's) StateT)
lemma StateT_iso:
rep_StateT (abs_StateT x) = x
abs_StateT (rep_StateT y) = y
lemma monad_locale_StateT:
monad_locale TYPE(('m, 's) StateT) fmap_stateT return_stateT bind_stateT
lemma fmap_StateT_def:
fmap == λf xs. abs_StateT (fmap_stateT f (rep_StateT xs))
lemma return_StateT_def:
return == λx. abs_StateT (return_stateT x)
lemma bind_StateT_def:
op >>= == λm k. abs_StateT (bind_stateT (rep_StateT m) (λx. rep_StateT (k x)))
lemma
do {u <- set_StateT x; get_StateT} = do {u <- set_StateT x; return x}