Theory MonadClass

Up to index of Isabelle/HOL/Constructor

theory MonadClass
imports FunctorClass
begin

header {* Monad Class *}

theory MonadClass
imports FunctorClass
begin

(* Monad operations on the U-based representation type U•'m, for any 'm of
   class tycon.  They are only declared here (no definitions); the axclass
   monad below states the properties they must satisfy, and monad_locale
   (later in this theory) assumes they are given by rep_return_of / rep_bind_of. *)
consts
  rep_return :: "U => U•'m::tycon"
  rep_bind :: "U•'m => (U => U•'m) => U•'m::tycon"

(* Class monad: a subclass of functor and tycon whose rep_return / rep_bind
   satisfy two typing axioms and three equations.
   Note that no right-unit axiom is stated (the equations are numbered 1 and 3);
   the theorem monad_right_unit later in this theory is instead proved from
   monad_fmap and the subgoal "fmap id m = m". *)
axclass monad < functor, tycon
  (* typing: rep_return maps an element of A to an element of tc TYPE('m) A *)
  monad_rep_return_type:
    "x ::: A ==> emb ((rep_return x)::U•'m::tycon) ::: tc TYPE('m) A"

  (* typing: binding a well-typed m with a continuation that is well-typed
     on A yields a result in tc TYPE('m) B *)
  monad_rep_bind_type:
    "[|emb (m::U•'m::tycon) ::: tc TYPE('m) A;
       !!x. x ::: A ==> emb (f x) ::: tc TYPE('m) B|]
      ==> emb (rep_bind m f) ::: tc TYPE('m) B"

  (* the functor map must coincide with bind-then-return *)
  monad_rep_fmap:
    "rep_fmap f m = rep_bind m (λx. rep_return (f x))"

  (* left unit *)
  monad_rep_1:
    "rep_bind (rep_return x) f = f x"

  (* associativity of bind *)
  monad_rep_3:
    "rep_bind (rep_bind m f) g = rep_bind m (λx. rep_bind (f x) g)"

(* Specialised form of monad_rep_return_type: the second schematic variable
   is instantiated with "rep_of TYPE('b)" (the first is left open), the result
   is simplified with tc_rep_of, and the theorem is put into standard form. *)
lemmas rep_return_type = monad_rep_return_type
    [of _ "rep_of TYPE('b)", simplified tc_rep_of, standard]

(* Specialised form of monad_rep_bind_type: the second and fourth schematic
   variables are instantiated with "rep_of TYPE('b)" and "rep_of TYPE('c)"
   (the first and third are left open), then simplified with tc_rep_of and
   put into standard form. *)
lemmas rep_bind_type = monad_rep_bind_type
    [of _ "rep_of TYPE('b)" _ "rep_of TYPE('c)",
       simplified tc_rep_of, standard]

(* Typed return for a monad 'm: embed the argument with emb, apply rep_return
   at type U•'m, and coerce the result to 'a•'m.
   The definition is kept in lambda form; proofs below unfold return_def. *)
constdefs
  return :: "'a => 'a•'m::monad"
  "return ≡ λx. coerce (rep_return (emb x)::U•'m)"

(* Typed bind for a monad 'm, written infix as >>= (left-associative,
   priority 55).  It coerces m to U•'m, wraps the continuation k as
   "coerce o k o proj" so that it acts on U, applies rep_bind, and coerces
   the result back.  Proofs below unfold bind_def. *)
constdefs
  bind :: "'a•'m::monad => ('a => 'b•'m) => 'b•'m" (infixl ">>=" 55)
  "bind ≡ λm k. coerce (rep_bind (coerce m::U•'m) (coerce o k o proj))"

(* Alternative infix notation for bind in xsymbols print mode: the symbol
   \<triangleright>, with the same associativity and priority as >>=. *)
syntax (xsymbols)
  bind :: "'a•'m::monad => ('a => 'b•'m) => 'b•'m" (infixl "\<triangleright>" 55)

text {* do notation *}

(* Syntactic categories for a single binding and a sequence of bindings. *)
nonterminals
  do_bind do_binds

(* Concrete syntax:
     _bind   one binding            "x <- m"
     _binds  a binding followed by further bindings, separated by ";"
     ""      a single binding used as a binding sequence
     _do     the block              "do {bindings; result}"            *)
syntax
  "_bind" :: "[pttrn, 'a•'m::monad] => do_bind" ("(2_ <- _)" 10)
  "_binds" :: "[do_bind, do_binds] => do_binds" ("_;/ _")
  "" :: "do_bind => do_binds" ("_")
  "_do" :: "[do_binds, 'a•'m::monad] => 'a•'m" ("(do {_; (_)})")

(* xsymbols variant of a binding, using \<leftarrow> instead of <- *)
syntax (xsymbols)
  "_bind" :: "pttrn => 'a•'m::monad => do_bind" ("(2_ \<leftarrow> _)" 10)

(* Translation into bind:
   - a block with several bindings is nested into a block per binding;
   - a block with one binding "x <- m" and body k becomes "bind m (λx. k)". *)
translations
  "_do (_binds b bs) e" == "_do b (_do bs e)"
  "_do (_bind x m) k" == "bind m (λx. k)"


text {* monad laws *}

(* fmap expressed through bind and return: the typed counterpart of the
   class axiom monad_rep_fmap. *)
theorem monad_fmap:
  "fmap f xs = xs \<triangleright> (λx. return (f x))"
 (* unfold the typed operations (and function composition) *)
 apply (simp add: fmap_def bind_def return_def o_def)
 (* rewrite rep_fmap with the class axiom *)
 apply (simp add: monad_rep_fmap)
 (* reduce the equation to one between the continuations: strip coerce,
    then "rep_bind (coerce xs)", then compare pointwise *)
 apply (rule_tac f=coerce in arg_cong)
 apply (rule_tac f="rep_bind (coerce xs)" in arg_cong)
 apply (rule ext)
 (* coerce_inverse is used right-to-left; the final simp is given
    rep_return_type, presumably to discharge its typing premise *)
 apply (rule coerce_inverse[symmetric])
 apply (simp add: rep_return_type)
done

(* Left unit law for the typed operations; declared as a simp rule. *)
theorem monad_left_unit [simp]: "(return x \<triangleright> f) = (f x)"
 (* unfold bind, return and function composition *)
 apply (simp add: bind_def return_def o_def)
 (* rewrite with coerce_inverse; the indented step handles the extra
    subgoal this introduces, using rep_return_type *)
 apply (subst coerce_inverse)
  apply (simp add: rep_return_type)
 (* finish with the class axiom for left unit *)
 apply (simp add: monad_rep_1)
done

(* Right unit law; declared as a simp rule.  There is no right-unit axiom in
   the class, so the proof goes through fmap: it introduces "fmap id m = m"
   as an intermediate fact. *)
theorem monad_right_unit [simp]: "(m \<triangleright> return) = m"
 apply (subgoal_tac "fmap id m = m")
  (* main goal: rewrite fmap in the intermediate fact via monad_fmap
     (and nothing else), then let simp close the goal *)
  apply (simp only: monad_fmap)
  apply simp
 (* remaining goal: the intermediate fact itself, solved by simp
    (presumably via a functor identity law from FunctorClass) *)
 apply simp
done

(* Associativity of bind for the typed operations (not a simp rule; later
   proofs add it explicitly). *)
theorem monad_bind_assoc: "((m \<triangleright> f) \<triangleright> g) = (m \<triangleright> (λx. f x \<triangleright> g))"
 (* unfold bind and function composition *)
 apply (simp add: bind_def o_def)
 (* rewrite with coerce_inverse; the indented steps handle the subgoal this
    introduces, using rep_bind_type and coerce_type *)
 apply (subst coerce_inverse)
  apply (rule rep_bind_type)
   apply (rule coerce_type, simp)
  apply (simp add: coerce_type)
 (* reassociate with the class axiom *)
 apply (simp add: monad_rep_3)
 (* reduce to an equation between continuations: strip coerce, then
    "rep_bind (coerce m)", then compare pointwise *)
 apply (rule_tac f=coerce in arg_cong)
 apply (rule_tac f="rep_bind (coerce m)" in arg_cong)
 apply (rule ext)
 (* coerce_inverse right-to-left, followed by the same typing argument
    as above (rep_bind_type, coerce_type) *)
 apply (rule coerce_inverse[symmetric])
 apply (rule rep_bind_type)
  apply (rule coerce_type, simp)
 apply (simp add: coerce_type)
done


text {* laws for fmap *}

(* fmap commutes with return.  Proved by rewriting fmap with monad_fmap;
   monad_left_unit is available to simp as a declared simp rule. *)
lemma fmap_return: "fmap f (return x) = return (f x)"
by (simp add: monad_fmap)

(* fmap over a bind can be pushed into the continuation. *)
lemma fmap_bind: "fmap f (bind m k) = bind m (λx. fmap f (k x))"
by (simp add: monad_fmap monad_bind_assoc)

(* binding after fmap f is the same as binding with k precomposed with f. *)
lemma bind_fmap: "bind (fmap f m) k = bind m (λx. k (f x))"
by (simp add: monad_fmap monad_bind_assoc)

(* Two continuations are equal iff binding with them agrees on every m.
   For the direction from the universally quantified equation to k1 = k2,
   the proof compares pointwise (ext) and instantiates m with "return x";
   simp then finishes (monad_left_unit is a simp rule). *)
lemma congruent_bind: "(∀m. m \<triangleright> k1 = m \<triangleright> k2) = (k1 = k2)"
 apply (safe, rule ext)
 apply (drule_tac x="return x" in spec, simp)
done

text {* laws for join *}

(* join flattens a doubly wrapped value by binding it with the identity
   continuation.  Kept in lambda form; the lemmas below unfold join_def. *)
constdefs
  join :: "('a•'m::monad)•'m => 'a•'m"
  "join ≡ λm. m \<triangleright> (λx. x)"

(* All lemmas in this group are proved by unfolding join (and, except for
   join_return, rewriting fmap via monad_fmap and reassociating with
   monad_bind_assoc), then simplifying. *)

(* join commutes with fmap: mapping inside then flattening equals
   flattening then mapping *)
lemma join_fmap_fmap: "join (fmap (fmap f) xss) = fmap f (join xss)"
by (simp add: join_def monad_fmap monad_bind_assoc)

(* flattening an outer return is the identity *)
lemma join_return: "join (return xs) = xs"
by (simp add: join_def)

(* flattening after wrapping every element with return is the identity *)
lemma join_fmap_return: "join (fmap return xs) = xs"
by (simp add: join_def monad_fmap monad_bind_assoc)

(* for a triply wrapped value, flattening the inner layers first or the
   outer layers first gives the same result *)
lemma join_fmap_join: "join (fmap join xsss) = join (join xsss)"
by (simp add: join_def monad_fmap monad_bind_assoc)

(* bind expressed through join and fmap *)
lemma bind_def2: "m \<triangleright> k = join (fmap k m)"
by (simp add: join_def monad_fmap monad_bind_assoc)


text {* equivalence of monad laws and fmap/join laws *}

(* The three monad laws re-derived from the fmap/join laws.  The lemmas are
   unnamed: they only demonstrate the derivation and are not referenced. *)

(* left unit, using only bind_def2, fmap_return and join_return *)
lemma "(return x \<triangleright> f) = (f x)"
by (simp only: bind_def2 fmap_return join_return)

(* right unit, using only bind_def2 and join_fmap_return *)
lemma "(m \<triangleright> return) = m"
by (simp only: bind_def2 join_fmap_return)

(* associativity *)
lemma "((m \<triangleright> f) \<triangleright> g) = (m \<triangleright> (λx. f x \<triangleright> g))"
 (* express every bind through join and fmap *)
 apply (simp only: bind_def2)
 (* intermediate equation relating the left-hand side to a nested
    join/fmap form *)
 apply (subgoal_tac "join (fmap g (join (fmap f m))) =
    join (fmap join (fmap (fmap g) (fmap f m)))")
  (* main goal, given the intermediate equation; fmap_fmap is not defined
     in this theory (presumably a functor composition law) *)
  apply (simp add: fmap_fmap)
 (* the intermediate equation itself, from the join laws *)
 apply (simp add: join_fmap_join join_fmap_fmap)
done


subsection {*locale for proving stuff about monad representations *}

(* Helpers that turn monad operations on a concrete type 'l into operations
   on the representation type U•'f by inserting coerce. *)
constdefs
  (* apply uret, then coerce the result to U•'f *)
  rep_return_of :: "(U => 'l) => (U => U•'f::tycon)"
  "rep_return_of uret ≡ coerce o uret"

  (* coerce m, coerce the result of the continuation k (coerce o k),
     apply ubind, and coerce the result back to U•'f *)
  rep_bind_of :: "('l => (U => 'l) => 'l) =>
                  (U•'f => (U => U•'f) => U•'f::tycon)"
  "rep_bind_of ubind ≡ λm k. coerce (ubind (coerce m) (coerce o k))"

(* Locale for showing that a type constructor 'f is a monad by giving
   concrete operations on a type 'l.  It extends functor_locale with:
     uret, ubind   the concrete return and bind on 'l
     rep_return    rep_return at 'f is rep_return_of uret
     rep_bind      rep_bind at 'f is rep_bind_of ubind
     umap_def      the functor map umap equals bind-then-return
     m1, m3        left unit and associativity for ubind
   No right-unit assumption (m2) is part of the locale; it appears only as
   a premise of the *_intro2 rules later in this theory. *)
locale monad_locale = functor_locale +
  fixes uret :: "U => 'l"
    and ubind :: "'l => (U => 'l) => 'l"
assumes rep_return: "rep_return :: U => U•'f::tycon
                      ≡ rep_return_of uret"
    and rep_bind: "rep_bind :: U•'f => (U => U•'f) => U•'f::tycon
                      ≡ rep_bind_of ubind"
    and umap_def: "!!f xs. umap f xs = ubind xs (λx. uret (f x))"
    and m1: "!!f x. ubind (uret x) f = f x"
    and m3: "!!xs f g. ubind (ubind xs f) g = ubind xs (λx. ubind (f x) g)"

(* umap commutes with uret (from umap_def and the left unit law m1) *)
lemma (in monad_locale) umap_uret: "umap f (uret x) = uret (f x)"
by (simp add: umap_def m1)

(* binding after umap f equals binding with k precomposed with f *)
lemma (in monad_locale) ubind_umap:
"ubind (umap f xs) k = ubind xs (λx. k (f x))"
by (simp add: umap_def m3 m1)

(* umap over a bind can be pushed into the continuation *)
lemma (in monad_locale) umap_ubind:
"umap f (ubind xs k) = ubind xs (λx. umap f (k x))"
by (simp add: umap_def m3)

(* Locale version of the class axiom monad_rep_return_type: rep_return maps
   an element of A to an element of tc TYPE('f) A.  Used by monad_class. *)
lemma (in monad_locale) return_type:
  "x ::: A ==> emb ((rep_return x)::U•'f::tycon)
       ::: tc TYPE('f) A"
 (* replace rep_return by its locale characterisation via uret *)
 apply (simp add: rep_return rep_return_of_def)
 (* emb_in_tc and cast_fixed are not defined in this theory; presumably
    they reduce the membership statement to an equation on uret, which
    umap_uret then simplifies -- TODO confirm against their statements *)
 apply (subst emb_in_tc)
 apply (simp add: umap_uret)
 apply (simp add: cast_fixed)
done

(* Locale version of the class axiom monad_rep_bind_type.  The typing
   hypothesis on the continuation f is the named assumption Pf; the typing
   hypothesis on m is a premise of the shown statement.  Used by
   monad_class. *)
lemma (in monad_locale) bind_type:
assumes Pf: "!!x. x ::: A ==> emb (f x) ::: tc TYPE('f::tycon) B"
shows "emb (m::U•'f) ::: tc TYPE('f::tycon) A
   ==> emb (rep_bind m f) ::: tc TYPE('f) B"
 (* replace rep_bind by its locale characterisation via ubind *)
 apply (simp add: rep_bind rep_bind_of_def)
 (* emb_in_tc / emb_in_tc_2 are not defined in this theory *)
 apply (simp add: emb_in_tc emb_in_tc_2)
 (* use the premise about m as a rewrite *)
 apply (erule subst)
 apply (simp add: ubind_umap umap_ubind)
 (* reduce to an equation between continuations of "ubind (coerce m)" *)
 apply (rule_tac f="ubind (coerce m)" in arg_cong)
 apply (rule ext)
 apply (rule emb_in_tc [THEN iffD1])
 apply simp
 (* the continuation's typing hypothesis, with its premise supplied by
    cast_in_idem (not defined in this theory) *)
 apply (rule Pf [OF cast_in_idem])
done

(* Locale version of the class axiom monad_rep_fmap, at type U•'f.
   Note: inside the locale this shares its base name with the top-level
   theorem monad_fmap. *)
lemma (in monad_locale) monad_fmap:
fixes m :: "U•'f::tycon"
shows "rep_fmap f m = rep_bind m (λx. rep_return (f x))"
 (* replace the rep_ operations by their locale characterisations *)
 apply (simp add: rep_fmap rep_bind rep_return)
 apply (simp add: rep_fmap_of_def rep_bind_of_def rep_return_of_def)
 (* rewrite umap (and only umap) via bind-then-return *)
 apply (simp only: umap_def)
 apply (simp add: o_def)
done

(* Locale version of the class axiom monad_rep_1 (left unit), obtained by
   unfolding the rep_ operations and applying m1. *)
lemma (in monad_locale) monad1:
fixes f :: "U => U•'f::tycon"
shows "rep_bind (rep_return x) f = f x"
by (simp add: rep_bind rep_return rep_bind_of_def rep_return_of_def m1)

(* Locale version of the class axiom monad_rep_3 (associativity). *)
lemma (in monad_locale) monad3:
fixes f :: "U => U•'f::tycon"
  and g :: "U => U•'f::tycon"
shows "rep_bind (rep_bind m f) g
        = rep_bind m (λx. rep_bind (f x) g)"
 (* replace rep_bind by its locale characterisation via ubind *)
 apply (simp add: rep_bind rep_bind_of_def)
 (* reassociate with m3, then normalise function composition *)
 apply (subst m3)
 apply (simp add: o_def)
done

text {* alternate introduction rule for monad-locale *}

(* Introduction rule for functor_locale from monad-style data: if umap is
   bind-then-return (umap_def) and ubind/uret satisfy the three monad laws
   m1, m2, m3, then umap gives a functor_locale instance for 'f.
   Unlike monad_locale itself, this rule needs the right-unit law m2. *)
lemma functor_locale_intro2:
  fixes umap :: "(U => U) => 'l => 'l"
    and uret :: "U => 'l"
    and ubind :: "'l => (U => 'l) => 'l"
assumes tc_umap: "monotc TYPE('f::tycon) ≡ functor_tc umap"
    and rep_fmap:  "rep_fmap :: (U => U) => U•'f => U•'f::tycon
                      ≡ rep_fmap_of umap"
    and umap_def: "!!f xs. umap f xs = ubind xs (λx. uret (f x))"
    and m1: "!!f x. ubind (uret x) f = f x"
    and m2: "!!xs. ubind xs uret = xs"
    and m3: "!!xs f g. ubind (ubind xs f) g = ubind xs (λx. ubind (f x) g)"
shows "functor_locale TYPE('f::tycon) umap"
 (* functor_locale.intro leaves four subgoals, discharged in order *)
 apply (rule functor_locale.intro)
    (* first two: the assumptions tc_umap and rep_fmap as given *)
    apply (rule tc_umap)
   apply (rule rep_fmap)
  (* third: follows from umap_def and right unit (presumably the functor
     identity law) *)
  apply (simp add: umap_def m2)
 (* fourth: follows from umap_def, associativity and left unit (presumably
    the functor composition law) *)
 apply (simp add: umap_def m3 m1)
done

(* Introduction rule for monad_locale that does not require a separately
   established functor_locale: the functor part is obtained from
   functor_locale_intro2, which is why the right-unit law m2 is among the
   premises even though monad_locale itself does not assume it. *)
lemma monad_locale_intro2:
  fixes umap :: "(U => U) => 'l => 'l"
    and uret :: "U => 'l"
    and ubind :: "'l => (U => 'l) => 'l"
assumes tc_umap: "monotc TYPE('f::tycon) ≡ functor_tc umap"
    and rep_fmap: "rep_fmap :: (U => U) => U•'f => U•'f::tycon
                      ≡ rep_fmap_of umap"
    and rep_return: "rep_return :: U => U•'f::tycon
                      ≡ rep_return_of uret"
    and rep_bind: "rep_bind :: U•'f => (U => U•'f) => U•'f::tycon
                      ≡ rep_bind_of ubind"
    and umap_def: "!!f xs. umap f xs = ubind xs (λx. uret (f x))"
    and m1: "!!f x. ubind (uret x) f = f x"
    and m2: "!!xs. ubind xs uret = xs"
    and m3: "!!xs f g. ubind (ubind xs f) g = ubind xs (λx. ubind (f x) g)"
shows "monad_locale TYPE('f::tycon) umap uret ubind"
 apply (rule monad_locale.intro)
  (* functor part: via functor_locale_intro2, premises by assumption *)
  apply (rule functor_locale_intro2, assumption+)
 (* monad-specific part: the locale's own axioms, by assumption *)
 apply (rule monad_locale_axioms.intro, assumption+)
done

(* Any 'f for which monad_locale holds is an instance of class monad.
   After intro_classes, each step discharges one class obligation. *)
lemma (in monad_locale) monad_class:
  "OFCLASS('f::tycon, monad_class)"
apply (intro_classes)
(* the facts type, cast and comp are not defined in this theory; presumably
   they come from functor_locale and cover the obligations inherited from
   the superclasses -- TODO confirm *)
apply (rule type, fast, assumption)
apply (rule cast)
apply (rule comp)
(* the remaining five steps use the locale lemmas that mirror the five
   axioms of axclass monad, in the order the axioms are declared *)
apply (erule return_type)
apply (rule bind_type, fast, assumption)
apply (rule monad_fmap)
apply (rule monad1)
apply (rule monad3)
done

text {* Other type instances *}

(* "monad locale to functor locale": the first element of
   monad_locale.axioms, used below (via THEN) to reach functor_locale facts
   from a monad_locale assumption. *)
lemmas ml2fl = monad_locale.axioms(1)

(* Characterises the typed return at 'a•'f through a concrete return_a:
   given a monad_locale instance for 'f (assumption ml) and the rewrite
   rules rews relating the mixed-type operations map_aU / bind_aU / return_a
   to the U-level ones, "return" equals return_a followed by coerce.

   Naming convention of the variables: the suffix gives the element types
   involved (a = 'a, U = U), e.g. map_aU maps a function 'a => U over 'k.

   bind_aU is declared explicitly in the fixes list, with the same type as
   in monad_locale_bind_def, instead of being left as an undeclared free
   variable of rews; its type is the one forced by map_aU, return_a and
   return_U, so the proved statement is unchanged. *)
lemma monad_locale_return_def:
  fixes return_a :: "'a => 'k"
  fixes return_U :: "U => 'l"
  fixes bind_UU :: "'l => (U => 'l) => 'l"
  fixes bind_aU :: "'k => ('a => 'l) => 'l"
  fixes map_aU :: "('a => U) => 'k => 'l"
  fixes map_UU :: "(U => U) => 'l => 'l"
  assumes ml: "monad_locale TYPE('f::monad) map_UU return_U bind_UU"
  assumes rews:
    "!!xs. emb xs = emb (map_aU emb xs)"
    "!!f xs. map_aU f xs = bind_aU xs (λx. return_U (f x))"
    "!!f x. bind_aU (return_a x) f = f x"
  shows "return ≡ λx. (coerce::'k => 'a•'f::monad) (return_a x)"
 (* turn the meta-equality into a pointwise object-level equation *)
 apply (rule eq_reflection, rule ext)
 apply (unfold return_def)
 (* replace rep_return by its locale characterisation via return_U *)
 apply (simp add: ml [THEN monad_locale.rep_return])
 apply (simp add: rep_return_of_def)
 (* fmap_coerce is not defined in this theory; it is used right-to-left
    and then fmap is unfolded down to the concrete map *)
 apply (simp add: fmap_coerce [symmetric])
 apply (simp add: fmap_def)
 apply (simp add: ml [THEN ml2fl [THEN functor_locale.rep_fmap]])
 apply (simp add: rep_fmap_of_def)
 apply (simp add: ml [THEN ml2fl [THEN functor_locale.rep_of_App_U_functor]])
 (* express the map through bind/return and apply left unit *)
 apply (simp add: ml [THEN monad_locale.umap_def])
 apply (simp add: ml [THEN monad_locale.m1])
 (* finish with the user-supplied rewrite rules *)
 apply (simp add: coerce_def rews)
done

(* Characterises the typed bind at 'a•'f / 'b•'f through a concrete bind_ab:
   given a monad_locale instance for 'f (assumption ml) and the rewrite
   rules rews relating the mixed-type variants of map / bind / return to
   each other, "bind m k" equals bind_ab applied to the coerced m and the
   coerced continuation, coerced back.

   Naming convention of the variables: the suffix gives the element types
   involved (a = 'a, b = 'b, U = U); 'k, 'm and 'l are the concrete carrier
   types for elements 'a, 'b and U respectively.

   bind_Ua is declared explicitly in the fixes list, like every other
   bind/map variant, instead of being left as an undeclared free variable
   of rews; its type is the one forced by map_Ua and return_a, so the proved
   statement is unchanged. *)
lemma monad_locale_bind_def:
  fixes return_U :: "U => 'l"
  fixes return_a :: "'a => 'k"
  fixes return_b :: "'b => 'm"
  fixes bind_UU :: "'l => (U => 'l) => 'l"
  fixes bind_aU :: "'k => ('a => 'l) => 'l"
  fixes bind_bU :: "'m => ('b => 'l) => 'l"
  fixes bind_Ua :: "'l => (U => 'k) => 'k"
  fixes bind_Ub :: "'l => (U => 'm) => 'm"
  fixes bind_ab :: "'k => ('a => 'm) => 'm"
  fixes map_Ua :: "(U => 'a) => 'l => 'k"
  fixes map_aU :: "('a => U) => 'k => 'l"
  fixes map_bU :: "('b => U) => 'm => 'l"
  fixes map_Ub :: "(U => 'b) => 'l => 'm"
  fixes map_UU :: "(U => U) => 'l => 'l"
  assumes ml: "monad_locale TYPE('f::monad) map_UU return_U bind_UU"
  assumes rews:
    "!!xs. emb xs = emb (map_bU emb xs)"
    "!!u. proj u = map_Ua proj (proj u)"
    "!!u. proj u = map_Ub proj (proj u)"
    "!!f xs. map_bU f xs = bind_bU xs (λx. return_U (f x))"
    "!!f xs. map_Ua f xs = bind_Ua xs (λx. return_a (f x))"
    "!!f xs. map_Ub f xs = bind_Ub xs (λx. return_b (f x))"
    "!!x k. bind_aU (return_a x) k = k x"
    "!!x k. bind_bU (return_b x) k = k x"
    "!!xs f g. bind_bU (bind_ab xs f) g = bind_aU xs (λx. bind_bU (f x) g)"
    "!!xs f g. bind_aU (bind_Ua xs f) g = bind_UU xs (λx. bind_aU (f x) g)"
    "!!xs f g. bind_bU (bind_Ub xs f) g = bind_UU xs (λx. bind_bU (f x) g)"
  shows "bind ≡ λ(m::'a•'f::monad) (k::'a => 'b•'f).
          coerce (bind_ab (coerce m) (λx. coerce (k x)))"
 (* turn the meta-equality into an object-level equation, pointwise in
    both arguments *)
 apply (rule eq_reflection, rule ext, rule ext)
 apply (unfold bind_def)
 (* replace rep_bind by its locale characterisation via bind_UU *)
 apply (simp add: ml [THEN monad_locale.rep_bind])
 apply (simp add: rep_bind_of_def)
 (* fmap_coerce is not defined in this theory; it is used right-to-left
    and then fmap is unfolded down to the concrete map *)
 apply (simp add: fmap_coerce [symmetric])
 apply (simp add: fmap_def o_def)
 apply (simp add: ml [THEN ml2fl [THEN functor_locale.rep_fmap]])
 apply (simp add: rep_fmap_of_def)
 apply (simp add: ml [THEN ml2fl [THEN functor_locale.rep_of_App_U_functor]])
 (* push maps through binds, fuse adjacent maps, then express the
    remaining map through bind/return *)
 apply (simp add: ml [THEN monad_locale.umap_ubind])
 apply (simp add: ml [THEN ml2fl [THEN functor_locale.umap_umap]])
 apply (simp add: ml [THEN monad_locale.umap_def])
 (* finish with the user-supplied rewrite rules *)
 apply (simp add: coerce_def rews)
done

end

lemmas rep_return_type:

  x ::: rep_of TYPE('b) ==> emb (rep_return x) ::: rep_of TYPE('b $ 'a)

lemmas rep_return_type:

  x ::: rep_of TYPE('b) ==> emb (rep_return x) ::: rep_of TYPE('b $ 'a)

lemmas rep_bind_type:

  [| emb m ::: rep_of TYPE('b $ 'a);
     !!x. x ::: rep_of TYPE('b) ==> emb (f x) ::: rep_of TYPE('c $ 'a) |]
  ==> emb (rep_bind m f) ::: rep_of TYPE('c $ 'a)

lemmas rep_bind_type:

  [| emb m ::: rep_of TYPE('b $ 'a);
     !!x. x ::: rep_of TYPE('b) ==> emb (f x) ::: rep_of TYPE('c $ 'a) |]
  ==> emb (rep_bind m f) ::: rep_of TYPE('c $ 'a)

theorem monad_fmap:

  fmap f xs = do {x <- xs; return (f x)}

theorem monad_left_unit:

  return x >>= f = f x

theorem monad_right_unit:

  m >>= return = m

theorem monad_bind_assoc:

  m >>= f >>= g = do {x <- m; f x >>= g}

lemma fmap_return:

  fmap f (return x) = return (f x)

lemma fmap_bind:

  fmap f (m >>= k) = do {x <- m; fmap f (k x)}

lemma bind_fmap:

  fmap f m >>= k = do {x <- m; k (f x)}

lemma congruent_bind:

  (∀m. m >>= k1.0 = m >>= k2.0) = (k1.0 = k2.0)

lemma join_fmap_fmap:

  MonadClass.join (fmap (fmap f) xss) = fmap f (MonadClass.join xss)

lemma join_return:

  MonadClass.join (return xs) = xs

lemma join_fmap_return:

  MonadClass.join (fmap return xs) = xs

lemma join_fmap_join:

  MonadClass.join (fmap MonadClass.join xsss) =
  MonadClass.join (MonadClass.join xsss)

lemma bind_def2:

  m >>= k = MonadClass.join (fmap k m)

lemma

  return x >>= f = f x

lemma

  m >>= return = m

lemma

  m >>= f >>= g = do {x <- m; f x >>= g}

locale for proving stuff about monad representations

lemma umap_uret:

  monad_locale TYPE('f) umap uret ubind ==> umap f (uret x) = uret (f x)

lemma ubind_umap:

  monad_locale TYPE('f) umap uret ubind
  ==> ubind (umap f xs) k = ubind xs (λx. k (f x))

lemma umap_ubind:

  monad_locale TYPE('f) umap uret ubind
  ==> umap f (ubind xs k) = ubind xs (λx. umap f (k x))

lemma return_type:

  [| monad_locale TYPE('f) umap uret ubind; x ::: A |]
  ==> emb (rep_return x) ::: tc TYPE('f) A

lemma bind_type:

  [| monad_locale TYPE('f) umap uret ubind;
     !!x. x ::: A ==> emb (f x) ::: tc TYPE('f) B; emb m ::: tc TYPE('f) A |]
  ==> emb (rep_bind m f) ::: tc TYPE('f) B

lemma monad_fmap:

  monad_locale TYPE('f) umap uret ubind
  ==> rep_fmap f m = rep_bind m (λx. rep_return (f x))

lemma monad1:

  monad_locale TYPE('f) umap uret ubind ==> rep_bind (rep_return x) f = f x

lemma monad3:

  monad_locale TYPE('f) umap uret ubind
  ==> rep_bind (rep_bind m f) g = rep_bind m (λx. rep_bind (f x) g)

lemma functor_locale_intro2:

  [| monotc TYPE('f) == functor_tc umap; rep_fmap == rep_fmap_of umap;
     !!f xs. umap f xs = ubind xs (λx. uret (f x)); !!f x. ubind (uret x) f = f x;
     !!xs. ubind xs uret = xs;
     !!xs f g. ubind (ubind xs f) g = ubind xs (λx. ubind (f x) g) |]
  ==> functor_locale TYPE('f) umap

lemma monad_locale_intro2:

  [| monotc TYPE('f) == functor_tc umap; rep_fmap == rep_fmap_of umap;
     rep_return == rep_return_of uret; rep_bind == rep_bind_of ubind;
     !!f xs. umap f xs = ubind xs (λx. uret (f x)); !!f x. ubind (uret x) f = f x;
     !!xs. ubind xs uret = xs;
     !!xs f g. ubind (ubind xs f) g = ubind xs (λx. ubind (f x) g) |]
  ==> monad_locale TYPE('f) umap uret ubind

lemma monad_class:

  monad_locale TYPE('f) umap uret ubind ==> OFCLASS('f, monad_class)

lemmas ml2fl:

  monad_locale TYPE('f) umap uret ubind ==> functor_locale TYPE('f) umap

lemmas ml2fl:

  monad_locale TYPE('f) umap uret ubind ==> functor_locale TYPE('f) umap

lemma monad_locale_return_def:

  [| monad_locale TYPE('f) map_UU return_U bind_UU;
     !!xs. emb xs = emb (map_aU emb xs);
     !!f xs. map_aU f xs = bind_aU xs (λx. return_U (f x));
     !!f x. bind_aU (return_a x) f = f x |]
  ==> return == λx. coerce (return_a x)

lemma monad_locale_bind_def:

  [| monad_locale TYPE('f) map_UU return_U bind_UU;
     !!xs. emb xs = emb (map_bU emb xs); !!u. proj u = map_Ua proj (proj u);
     !!u. proj u = map_Ub proj (proj u);
     !!f xs. map_bU f xs = bind_bU xs (λx. return_U (f x));
     !!f xs. map_Ua f xs = bind_Ua xs (λx. return_a (f x));
     !!f xs. map_Ub f xs = bind_Ub xs (λx. return_b (f x));
     !!x k. bind_aU (return_a x) k = k x; !!x k. bind_bU (return_b x) k = k x;
     !!xs f g. bind_bU (bind_ab xs f) g = bind_aU xs (λx. bind_bU (f x) g);
     !!xs f g. bind_aU (bind_Ua xs f) g = bind_UU xs (λx. bind_aU (f x) g);
     !!xs f g. bind_bU (bind_Ub xs f) g = bind_UU xs (λx. bind_bU (f x) g) |]
  ==> op >>= == λm k. coerce (bind_ab (coerce m) (λx. coerce (k x)))