module SourceMonad(cMonadPlus, mzeroMPfun, mplusMPfun, instsMonadPlus,
                   monadDefns) where

import Testbed
import HaskellPrims
import HaskellPrelude

-----------------------------------------------------------------------------
-- Test Framework:

-- Entry point for interactive use: hands 'imports' and 'monadDefns' to the
-- `test` driver.  `test` is defined outside this file (presumably in
-- Testbed), so what it prints or checks is not visible here.
main :: IO ()
main = test imports monadDefns

-- Hands the label "HaskellMonad" together with 'imports' and 'monadDefns'
-- to `save`.  `save` is defined outside this file; presumably it writes the
-- checked definitions out under that name -- confirm against Testbed.
saveMonad :: IO ()
saveMonad = save "HaskellMonad" imports monadDefns

-- Assumptions made available to the definitions in this module: those of
-- HaskellPrims followed by those of HaskellPrelude, in that order.
imports :: [Assump]
imports = concat [defnsHaskellPrims, defnsHaskellPrelude]

-----------------------------------------------------------------------------
-- Test Program:

-- Class description for MonadPlus: its only superclass is cMonad and its
-- instances are the ones listed in 'instsMonadPlus'.  (cMonadPlus and
-- instsMonadPlus refer to each other; laziness makes that well defined.)
cMonadPlus
 = Class { name  = "MonadPlus"
         , super = [cMonad]
         , insts = instsMonadPlus
         }

-- Assumption for the class member "mzero".  Read with TGen 0 as the monad
-- m (kind * -> *) and TGen 1 as a (kind *), the scheme is
--   MonadPlus m => m a
mzeroMPfun
 = "mzero" :>: Forall kinds (context :=> result)
   where kinds   = [Kfun Star Star, Star]
         monad   = TGen 0
         context = [isIn1 cMonadPlus monad]
         result  = TAp monad (TGen 1)

-- Assumption for the class member "mplus".  Read with TGen 0 as the monad
-- m (kind * -> *) and TGen 1 as a (kind *), the scheme is
--   MonadPlus m => m a -> m a -> m a
mplusMPfun
 = "mplus" :>: Forall kinds (context :=> combine)
   where kinds   = [Kfun Star Star, Star]
         monad   = TGen 0
         context = [isIn1 cMonadPlus monad]
         action  = TAp monad (TGen 1)
         combine = action `fn` action `fn` action

-- Instances of MonadPlus: one for tMaybe and one for tList, in that order.
-- Each is built by mkInst with an empty first argument and an empty
-- context, i.e. the instance has no preconditions.
instsMonadPlus
 = map plusInstance [tMaybe, tList]
   where plusInstance tycon = mkInst [] ([] :=> isIn1 cMonadPlus tycon)

-- Bind groups for the monad utility functions: one explicitly typed binding
-- per group, in the order listed below.  Every entry is a triple
-- (identifier, declared type scheme, alternatives).
--
-- The comment above each entry is the Haskell-source reading of that entry.
-- That reading assumes `ap` builds a left-nested application, `fn` is the
-- right-associative function-type constructor, and mbindMfun / returnMfun
-- are the Monad class's bind and return members (all defined outside this
-- file -- confirm there).
monadDefns :: [BindGroup]
monadDefns
 = [toBg [defn] | defn <- defns]
 where
  defns
   = [ -- msum :: MonadPlus m => [m a] -> m a
       -- msum = foldr mplus mzero
       ("msum",
        sigOf cMonadPlus 1 (listOf (inM tv1) `fn` inM tv1),
        [([],
          ap [evar "foldr", econst mplusMPfun, econst mzeroMPfun])])

     , -- join :: Monad m => m (m a) -> m a
       -- join x = x >>= id
       ("join",
        sigOf cMonad 1 (inM (inM tv1) `fn` inM tv1),
        [([PVar "x"],
          ap [econst mbindMfun, evar "x", evar "id"])])

     , -- when :: Monad m => Bool -> m () -> m ()
       -- when p s = if p then s else return ()
       ("when",
        sigOf cMonad 0 (tBool `fn` inM tUnit `fn` inM tUnit),
        [([PVar "p", PVar "s"],
          If (evar "p") (evar "s") returnUnit)])

     , -- unless :: Monad m => Bool -> m () -> m ()
       -- unless p s = when (not p) s
       ("unless",
        sigOf cMonad 0 (tBool `fn` inM tUnit `fn` inM tUnit),
        [([PVar "p", PVar "s"],
          ap [evar "when", ap [evar "not", evar "p"], evar "s"])])

     , -- guard :: MonadPlus m => Bool -> m ()
       -- guard p = if p then return () else mzero
       ("guard",
        sigOf cMonadPlus 0 (tBool `fn` inM tUnit),
        [([PVar "p"],
          If (evar "p") returnUnit (econst mzeroMPfun))])

     , -- mapAndUnzipM :: Monad m => (a -> m (b, c)) -> [a] -> m ([b], [c])
       -- mapAndUnzipM f xs = sequence (map f xs) >>= return . unzip
       ("mapAndUnzipM",
        sigOf cMonad 3
          ((tv1 `fn` inM (pairOf tv2 tv3))
           `fn` listOf tv1
           `fn` inM (pairOf (listOf tv2) (listOf tv3))),
        [([PVar "f", PVar "xs"],
          ap [econst mbindMfun,
              ap [evar "sequence", ap [evar "map", evar "f", evar "xs"]],
              ap [evar ".", econst returnMfun, evar "unzip"]])])

     , -- zipWithM :: Monad m => (a -> b -> m c) -> [a] -> [b] -> m [c]
       -- zipWithM f xs ys = sequence (zipWith f xs ys)
       ("zipWithM",
        sigOf cMonad 3
          ((tv1 `fn` tv2 `fn` inM tv3)
           `fn` listOf tv1
           `fn` listOf tv2
           `fn` inM (listOf tv3)),
        [([PVar "f", PVar "xs", PVar "ys"],
          ap [evar "sequence", zipped])])

     , -- zipWithM_ :: Monad m => (a -> b -> m c) -> [a] -> [b] -> m ()
       -- zipWithM_ f xs ys = sequence_ (zipWith f xs ys)
       ("zipWithM_",
        sigOf cMonad 3
          ((tv1 `fn` tv2 `fn` inM tv3)
           `fn` listOf tv1
           `fn` listOf tv2
           `fn` inM tUnit),
        [([PVar "f", PVar "xs", PVar "ys"],
          ap [evar "sequence_", zipped])])

     , -- foldM :: Monad m => (a -> b -> m a) -> a -> [b] -> m a
       -- foldM f a []     = return a
       -- foldM f a (x:xs) = f a x >>= \y -> foldM f y xs
       ("foldM",
        sigOf cMonad 2
          ((tv1 `fn` tv2 `fn` inM tv1)
           `fn` tv1
           `fn` listOf tv2
           `fn` inM tv1),
        [([PVar "f", PVar "a", nilPat],
          ap [econst returnMfun, evar "a"]),
         ([PVar "f", PVar "a", consPat],
          bind1 (ap [evar "f", evar "a", evar "x"]) "y"
                (ap [evar "foldM", evar "f", evar "y", evar "xs"]))])

     , -- filterM :: Monad m => (a -> m Bool) -> [a] -> m [a]
       -- filterM p []     = return []
       -- filterM p (x:xs) = p x >>= \b ->
       --                    filterM p xs >>= \ys ->
       --                    return (if b then x:ys else ys)
       ("filterM",
        sigOf cMonad 1
          ((tv1 `fn` inM tBool)
           `fn` listOf tv1
           `fn` inM (listOf tv1)),
        [([PVar "p", nilPat],
          ap [econst returnMfun, econst nilCfun]),
         ([PVar "p", consPat],
          bind1 (ap [evar "p", evar "x"]) "b"
                (bind1 (ap [evar "filterM", evar "p", evar "xs"]) "ys"
                       (ap [econst returnMfun,
                            If (evar "b")
                               (ap [econst consCfun, evar "x", evar "ys"])
                               (evar "ys")])))])

     , -- liftM  f = \a -> a >>= \a' -> return (f a')
       -- liftM2 f = \a b -> a >>= \a' -> b >>= \b' -> return (f a' b')
       -- ... and likewise up to liftM5; see liftN below.
       liftN "liftM"  1
     , liftN "liftM2" 2
     , liftN "liftM3" 3
     , liftN "liftM4" 4
     , liftN "liftM5" 5

     , -- ap :: Monad m => m (a -> b) -> m a -> m b
       -- ap = liftM2 ($)
       ("ap",
        sigOf cMonad 2
          (inM (tv1 `fn` tv2)
           `fn` inM tv1
           `fn` inM tv2),
        [([],
          ap [evar "liftM2", evar "$"])])
     ]

  -- Quantified variables: TGen 0 is the monad (given kind * -> * by sigOf),
  -- TGen 1 .. TGen 3 are the ordinary type variables (kind *).
  mTy = TGen 0
  tv1 = TGen 1
  tv2 = TGen 2
  tv3 = TGen 3

  -- Type-building shorthands: m t, [t] and (l, r).
  inM t      = TAp mTy t
  listOf t   = TAp tList t
  pairOf l r = TAp (TAp tTuple2 l) r

  -- Declared scheme: quantify over the monad plus n variables of kind *,
  -- constrain the monad to class cls, and use body as the type.
  sigOf cls n body
   = Just (Forall (Kfun Star Star : replicate n Star)
                  ([isIn1 cls mTy] :=> body))

  -- Expression and pattern fragments shared by several entries.
  returnUnit = ap [econst returnMfun, econst unitCfun]
  zipped     = ap [evar "zipWith", evar "f", evar "xs", evar "ys"]
  nilPat     = PCon nilCfun []
  consPat    = PCon consCfun [PVar "x", PVar "xs"]

  -- The n-argument lifting function, named ident:
  --   Monad m => (t1 -> .. -> tn -> r) -> m t1 -> .. -> m tn -> m r
  --   ident f = \a b .. -> a >>= \a' -> b >>= \b' -> .. return (f a' b' ..)
  -- Arguments are named "a", "b", ...; each bound result is the argument
  -- name with a prime appended.
  liftN ident n
   = (ident,
      sigOf cMonad (n + 1) (foldr1 fn (foldr1 fn tvs : map inM tvs)),
      [([PVar "f"], Lam (map PVar args, body))])
   where
    tvs    = map TGen [1 .. n + 1]
    args   = map (: []) (take n ['a' ..])
    primed = map (++ "'") args
    body   = foldr (\(arg, arg') rest -> bind1 (evar arg) arg' rest)
                   (ap [econst returnMfun, ap (evar "f" : map evar primed)])
                   (zip args primed)

-- Helper that keeps long chains of monadic binds readable: applies
-- mbindMfun (by its name, the Monad class's bind member) to an action and
-- to a one-argument lambda that names the action's result and continues
-- with the given right-hand side, i.e.  action >>= \var -> rest.
bind1 :: Expr -> Id -> Expr -> Expr
bind1 action var rest
 = ap [econst mbindMfun, action, continuation]
   where continuation = Lam ([PVar var], rest)

-----------------------------------------------------------------------------
