Theory StateT

Up to index of Isabelle/HOL/Constructor

theory StateT
imports MonadClass
begin

header {* State monad transformer *}

theory StateT
imports MonadClass
begin

subsection {* Type definition and monad operations *}

(* Concrete state monad transformer. A computation is a function from an
   initial state of type 's to a value of type ('a × 's)•'m, i.e. the type
   constructor 'm (sort tycon) applied to a result/final-state pair. *)
datatype ('a,'m,'s) stateT =
  stateT "'s => ('a × 's)•'m::tycon"

(* Destructor for stateT: unwraps the state-transition function stored by the
   single constructor. *)
consts
  run_stateT :: "('a,'m,'s) stateT => 's => ('a × 's)•'m::tycon"

primrec
  "run_stateT (stateT x) = x"

(* Re-wrapping the unwrapped function gives back the original value; declared
   [simp]. Proved by case analysis on the datatype followed by the primrec
   equation for run_stateT. *)
lemma run_stateT_inverse [simp]: "stateT (run_stateT x) = x"
by (induct x, simp)


constdefs
  (* Map a function over the result component of every outcome; the state
     threaded alongside each result is passed through unchanged. *)
  fmap_stateT :: "('a => 'b) => ('a,'m::functor,'s)stateT => ('b,'m,'s)stateT"
  "fmap_stateT ≡ λg st. stateT (λs0.
     fmap (λ(a, s1). (g a, s1)) (run_stateT st s0))"

  (* Yield the given value and leave the incoming state as it is. *)
  return_stateT :: "'a => ('a,'m::monad,'s)stateT"
  "return_stateT ≡ λv. stateT (λs0. return (v, s0))"

  (* Run the first computation from the incoming state, then run the
     continuation on its result, starting from the state it produced. *)
  bind_stateT :: "('a,'m::monad,'s)stateT =>
                  ('a => ('c,'m,'s)stateT) => ('c,'m,'s)stateT"
  "bind_stateT ≡ λst h. stateT (λs0.
     do {(v, s1) \<leftarrow> run_stateT st s0; run_stateT (h v) s1})"

constdefs
  (* Lift an 'm-computation into stateT: each result is paired with the
     incoming state, which is left unchanged. *)
  lift_stateT :: "'a•'m => ('a,'m::functor,'s)stateT"
  "lift_stateT ≡ λact. stateT (λs0. fmap (λv. (v, s0)) act)"

  (* Return the current state as the result without modifying it. *)
  get_stateT :: "('s,'m::monad,'s)stateT"
  "get_stateT ≡ stateT (λs0. return (s0, s0))"

  (* Discard the current state, install the given one, and return (). *)
  set_stateT :: "'s => (unit,'m::monad,'s)stateT"
  "set_stateT ≡ λnew. stateT (λold. return ((), new))"

(* Functor identity law for fmap_stateT. *)
lemma fmap_id_stateT: "fmap_stateT id xs = xs"
by (simp add: fmap_stateT_def split_def)

(* Functor composition law for fmap_stateT, reduced to the fmap_fmap law of
   the underlying type constructor. *)
lemma fmap_fmap_stateT:
  "fmap_stateT f (fmap_stateT g xs) = fmap_stateT (λx. f (g x)) xs"
by (simp add: fmap_stateT_def fmap_fmap split_def)

(* fmap_stateT coincides with binding and then returning the mapped result.
   After unfolding the three definitions, this follows from monad_fmap of the
   underlying monad. *)
lemma monad_fmap_stateT:
  "fmap_stateT f xs = bind_stateT xs (λx. return_stateT (f x))"
apply (simp add: fmap_stateT_def bind_stateT_def return_stateT_def)
apply (simp add: monad_fmap split_def)
done

(* Left-unit law: binding a returned value just applies the continuation. *)
lemma monad_left_unit_stateT: "bind_stateT (return_stateT x) f = f x"
by (simp add: bind_stateT_def return_stateT_def)

(* Associativity of bind_stateT, reduced to monad_bind_assoc of the
   underlying monad. *)
lemma monad_bind_assoc_stateT:
  "bind_stateT (bind_stateT xs f) g =
   bind_stateT xs (λx. bind_stateT (f x) g)"
by (simp add: bind_stateT_def monad_bind_assoc split_def)

subsection {* @{text rep} instance *}

(* Make stateT an instance of rep_consts so that emb/proj can be defined. *)
instance stateT :: (rep, functor, "{rep,finrep}") rep_consts ..

(* emb: first map emb over the results (fmap_stateT emb), then unwrap, then
   embed the resulting function with the emb of the function type.
   proj: the mirror image (proj the function, re-wrap, map proj over the
   results). *)
defs (overloaded)
  emb_stateT_def: "emb ≡ emb o run_stateT o fmap_stateT emb"
  proj_stateT_def: "proj ≡ fmap_stateT proj o stateT o proj"

(* The rep class axioms (presumably proj after emb is the identity -- confirm
   against the class definition) reduce, after unfolding, to the functor laws
   fmap_fmap_stateT and fmap_id_stateT. *)
instance stateT :: (rep, functor, "{rep,finrep}") rep
apply (intro_classes)
apply (unfold emb_stateT_def proj_stateT_def)
apply (simp add: fmap_fmap_stateT fmap_id_stateT [unfolded id_def])
done

subsection {* @{text functor} instance *}

(* StateT is a tag type. Applied with •, 'a•('m,'s) StateT is isomorphic to
   ('a,'m,'s) stateT (see StateT_iso), which lets the transformer be used
   where a type constructor (class tycon) is expected. *)
datatype ('m,'s) StateT = StateT

(* The tycon instance is discharged by the default method (..). monotc for
   StateT is then defined by applying functor_tc to fmap_stateT, instantiated
   at the representation type U. *)
instance StateT :: (functor, "{rep,finrep}") tycon ..
defs (overloaded)
  monotc_StateT_def:
    "monotc (t::('m::functor,'s::{rep,finrep}) StateT itself) ≡ functor_tc
      (fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"

text "The type constructor StateT, with fmap_stateT as its map function, satisfies @{term functor_locale}."

(* Class-level rep_fmap on StateT, obtained via rep_fmap_of from fmap_stateT
   instantiated at the representation type U. *)
defs (overloaded)
  rep_fmap_StateT_def:
  "rep_fmap::(U => U) =>
     U•('m::functor,'s::{rep,finrep})StateT => U•('m,'s)StateT
     ≡ rep_fmap_of
     (fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"

(* StateT together with fmap_stateT satisfies functor_locale. The premises of
   functor_locale.intro are discharged in order: the monotc and rep_fmap
   definitions, then the identity and composition laws. *)
lemma functor_locale_StateT:
  "functor_locale TYPE(('m::functor,'s::{rep,finrep}) StateT)
     (fmap_stateT :: (U => U) => (U,'m,'s) stateT => (U,'m,'s) stateT)"
apply (rule functor_locale.intro)
apply (rule monotc_StateT_def)
apply (rule rep_fmap_StateT_def)
apply (rule fmap_id_stateT)
apply (rule fmap_fmap_stateT)
done

(* functor class instance for StateT, derived from functor_locale_StateT via
   functor_locale.functor_class. *)
instance StateT :: (functor, "{rep,finrep}") functor
apply (rule functor_locale.functor_class)
apply (rule functor_locale_StateT)
done

subsection {* StateT is isomorphic to stateT *}

(* Conversions between the concrete transformer ('a,'m,'s) stateT and its
   tycon application 'a•('m,'s) StateT. Both are defined as coerce;
   StateT_iso shows that they are mutually inverse. *)
constdefs
  abs_StateT :: "('a,'m::functor,'s::{finrep,rep}) stateT => 'a•('m,'s) StateT"
  "abs_StateT ≡ coerce"

  rep_StateT :: "'a•('m::functor,'s::{finrep,rep}) StateT => ('a,'m,'s) stateT"
  "rep_StateT ≡ coerce"

(* emb on a stateT is unchanged if emb is first mapped over the results. Used
   as a premise by rep_of_stateT and by the fmap/return/bind characterisations
   of StateT. *)
lemma emb_stateT: "emb xs = emb (fmap_stateT emb xs)"
by (simp add: emb_stateT_def fmap_id_stateT)

(* Dual of emb_stateT: proj into stateT is unchanged if proj is additionally
   mapped over the results afterwards. *)
lemma proj_stateT: "proj y = fmap_stateT proj (proj y)"
by (simp add: proj_stateT_def fmap_id_stateT)

(* stateT and its StateT tycon application have the same representation
   (rep_of). This is an instance of functor_locale.rep_of_App_functor, whose
   remaining premises are the emb/proj lemmas and the composition law. *)
lemma rep_of_stateT:
  "rep_of TYPE(('a,'m::functor,'s::{finrep,rep}) stateT) =
   rep_of TYPE('a•('m,'s) StateT)"
apply (rule functor_locale.rep_of_App_functor)
apply (rule functor_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)
apply (rule fmap_fmap_stateT)
done

(* abs_StateT and rep_StateT are mutually inverse; both directions are simp
   rules. This rests on rep_of_stateT: the two types share a representation. *)
lemma StateT_iso [simp]:
  "rep_StateT (abs_StateT x) = x"
  "abs_StateT (rep_StateT y) = y"
by (simp_all add: rep_StateT_def abs_StateT_def rep_of_stateT)

subsection {* @{text monad} instance *}

(* Class-level return and bind on StateT, obtained via rep_return_of and
   rep_bind_of from return_stateT and bind_stateT instantiated at the
   representation type U. *)
defs (overloaded)
  rep_return_StateT_def:
   "rep_return::U => U•('m::monad,'s::{rep,finrep})StateT
     ≡ rep_return_of (return_stateT::U => (U,'m,'s)stateT)"

  rep_bind_StateT_def:
   "rep_bind::U•('m::monad,'s::{rep,finrep})StateT =>
     (U => U•('m,'s)StateT) => U•('m,'s)StateT
     ≡ rep_bind_of (bind_stateT::
       (U,'m,'s)stateT => (U => (U,'m,'s)stateT) => (U,'m,'s)stateT)"

(* StateT with fmap_stateT, return_stateT and bind_stateT satisfies
   monad_locale. The functor part comes from functor_locale_StateT. The
   monad_locale_axioms premises are discharged in order: the return/bind
   definitions, then the fmap, left-unit and associativity laws. *)
lemma monad_locale_StateT:
  "monad_locale TYPE(('m::monad,'s::{rep,finrep})StateT)
    fmap_stateT (return_stateT::U => (U,'m,'s)stateT) bind_stateT"
apply (rule monad_locale.intro)
apply (rule functor_locale_StateT)
apply (rule monad_locale_axioms.intro)
apply (rule rep_return_StateT_def)
apply (rule rep_bind_StateT_def)
apply (rule monad_fmap_stateT)
apply (rule monad_left_unit_stateT)
apply (rule monad_bind_assoc_stateT)
done

(* monad class instance for StateT, derived from monad_locale_StateT via
   monad_locale.monad_class. The apply steps are indented like the other
   proof scripts in this theory, e.g. the functor instance. *)
instance StateT :: (monad, "{rep,finrep}") monad
apply (rule monad_locale.monad_class)
apply (rule monad_locale_StateT)
done

subsection {* Other functions *}

constdefs
  (* Lift an 'm-computation into the StateT tycon application by lifting it
     into stateT and converting with abs_StateT. *)
  lift_StateT :: "'a•'m => 'a•('m::functor,'s::{finrep,rep}) StateT"
  "lift_StateT ≡ λact. abs_StateT (lift_stateT act)"

  (* get_stateT, viewed through the StateT isomorphism. *)
  get_StateT :: "'s•('m::monad,'s::{finrep,rep}) StateT"
  "get_StateT ≡ abs_StateT get_stateT"

  (* set_stateT, viewed through the StateT isomorphism. *)
  set_StateT :: "'s => unit•('m::monad,'s::{finrep,rep}) StateT"
  "set_StateT ≡ λnew. abs_StateT (set_stateT new)"

(* The class-level fmap on StateT, expressed as fmap_stateT conjugated by the
   abs/rep isomorphism. Obtained from functor_locale_fmap_def; the trailing
   "+" lets fmap_fmap_stateT close every remaining composition goal. *)
lemma fmap_StateT_def:
  "fmap ≡ λf xs. abs_StateT (fmap_stateT f (rep_StateT xs))"
apply (unfold abs_StateT_def rep_StateT_def)
apply (rule functor_locale_fmap_def)
apply (rule functor_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)
apply (rule fmap_fmap_stateT)+
done

(* The class-level return on StateT, expressed as return_stateT followed by
   abs_StateT. Obtained from monad_locale_return_def. *)
lemma return_StateT_def: "return ≡ λx. abs_StateT (return_stateT x)"
apply (unfold abs_StateT_def)
apply (rule monad_locale_return_def)
apply (rule monad_locale_StateT)
apply (rule emb_stateT)
apply (rule monad_fmap_stateT)
apply (rule monad_left_unit_stateT)
done

(* The class-level bind on StateT, expressed as bind_stateT conjugated by the
   abs/rep isomorphism. Obtained from monad_locale_bind_def; the "+" steps
   close every remaining goal of each shape. *)
lemma bind_StateT_def:
  "bind ≡ λm k. abs_StateT
    (bind_stateT (rep_StateT m) (λx. rep_StateT (k x)))"
apply (unfold abs_StateT_def rep_StateT_def)
apply (rule monad_locale_bind_def)
apply (rule monad_locale_StateT)
apply (rule emb_stateT)
apply (rule proj_stateT)+
apply (rule monad_fmap_stateT)+
apply (rule monad_left_unit_stateT)+
apply (rule monad_bind_assoc_stateT)+
done

(* Example: reading the state right after setting it to x yields x. The proof
   first rewrites the StateT operations into their stateT counterparts, then
   unfolds the stateT definitions. *)
lemma "do {u \<leftarrow> set_StateT x; get_StateT} = do {u \<leftarrow> set_StateT x; return x}"
apply (simp add:
  bind_StateT_def return_StateT_def set_StateT_def get_StateT_def)
apply (simp add:
  bind_stateT_def return_stateT_def set_stateT_def get_stateT_def)
done

end

Type definition and monad operations

lemma run_stateT_inverse:

  stateT (run_stateT x) = x

lemma fmap_id_stateT:

  fmap_stateT id xs = xs

lemma fmap_fmap_stateT:

  fmap_stateT f (fmap_stateT g xs) = fmap_stateT (λx. f (g x)) xs

lemma monad_fmap_stateT:

  fmap_stateT f xs = bind_stateT xs (λx. return_stateT (f x))

lemma monad_left_unit_stateT:

  bind_stateT (return_stateT x) f = f x

lemma monad_bind_assoc_stateT:

  bind_stateT (bind_stateT xs f) g = bind_stateT xs (λx. bind_stateT (f x) g)

@{text rep} instance

@{text functor} instance

lemma functor_locale_StateT:

  functor_locale TYPE(('m, 's) StateT) fmap_stateT

StateT is isomorphic to stateT

lemma emb_stateT:

  emb xs = emb (fmap_stateT emb xs)

lemma proj_stateT:

  proj y = fmap_stateT proj (proj y)

lemma rep_of_stateT:

  rep_of TYPE(('a, 'm, 's) stateT) = rep_of TYPE('a $ ('m, 's) StateT)

lemma StateT_iso:

  rep_StateT (abs_StateT x) = x
  abs_StateT (rep_StateT y) = y

@{text monad} instance

lemma monad_locale_StateT:

  monad_locale TYPE(('m, 's) StateT) fmap_stateT return_stateT bind_stateT

Other functions

lemma fmap_StateT_def:

  fmap == λf xs. abs_StateT (fmap_stateT f (rep_StateT xs))

lemma return_StateT_def:

  return == λx. abs_StateT (return_stateT x)

lemma bind_StateT_def:

  op >>= == λm k. abs_StateT (bind_stateT (rep_StateT m) (λx. rep_StateT (k x)))

lemma

  do {u <- set_StateT x; get_StateT} = do {u <- set_StateT x; return x}