{-
This version should correspond exactly to what appears in the published paper.
It also includes essential support code omitted from the paper for lack of space.
-}

import Array
import List
import IOExts
import System

--------------------------------------------------------------------------- 
-- Figure 1
--------------------------------------------------------------------------- 
-- | A CSP variable, identified by an index in @1..vars@.
type Var = Int

-- | A domain value, drawn from @1..vals@.
type Value = Int

-- | A variable bound to a value, written @v := x@.
data Assignment = Var := Value

-- | The variable half of an assignment.
var :: Assignment -> Var
var (v := _) = v

-- | The value half of an assignment.
value :: Assignment -> Value
value (_ := x) = x

-- | Pairwise compatibility test; 'True' means the two assignments may coexist.
type Relation = Assignment -> Assignment -> Bool

-- | A binary CSP: @vars@ variables, each ranging over @1..vals@, constrained
-- pairwise by @rel@.
data CSP = CSP
  { vars :: Int
  , vals :: Int
  , rel  :: Relation
  }

-- | Search state: the assignments made so far (most recent first) together
-- with the variables still to assign (the next one first).
data State = State ([Assignment], [Var])

-- | Assignments made so far, most recent first.
assignments :: State -> [Assignment]
assignments (State (done, _)) = done

-- | Variables not yet assigned, in the order they will be tried.
unassigned :: State -> [Var]
unassigned (State (_, todo)) = todo

-- | Root state: nothing assigned, variables queued as @1..vars@.
emptyState :: CSP -> State
emptyState CSP{vars = n} = State ([], [1 .. n])

-- | True for the state with no assignments (the search-tree root).
isEmptyState :: State -> Bool
isEmptyState st = null (assignments st)

-- | Child states: one per value @1..vals@ of the next unassigned variable,
-- each prepending that assignment. A complete state has no children.
extensions :: CSP -> State -> [State]
extensions _ (State (_, [])) = []
extensions CSP{vals = n} (State (done, v : rest)) =
  map (\x -> State ((v := x) : done, rest)) [1 .. n]

-- | Move @next@ to the front of the unassigned variables so it is tried next.
-- A complete state is returned unchanged, and in that case the chosen
-- variable is never evaluated.
newNextVar :: State -> Var -> State
newNextVar st chosen =
  case st of
    State (_, [])      -> st
    State (done, todo) -> State (done, chosen : delete chosen todo)

-- | A state is complete once no variables remain unassigned.
complete :: State -> Bool
complete st = null (unassigned st)

-- | The most recent assignment (partial: fails on the empty state).
lastAssignment :: State -> Assignment
lastAssignment st = head (assignments st)

-- | The variable to assign next (partial: fails on a complete state).
nextVar :: State -> Var
nextVar st = head (unassigned st)
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 2
--------------------------------------------------------------------------- 
-- | Enumerate every complete state level by level: depth @k@ consists of all
-- extensions of the states at depth @k-1@.
generate :: CSP -> [State]
generate csp@CSP{vars = n} = level n
  where
    level 0 = [emptyState csp]
    level k = concatMap (extensions csp) (level (k - 1))

-- | Every pair of assigned variables (higher index first) whose assignments
-- violate the CSP's relation.
inconsistencies :: CSP -> State -> [(Var, Var)]
inconsistencies CSP{rel = compatible} st =
  [ (var x, var y)
  | x <- done
  , y <- done
  , var x > var y
  , not (compatible x y)
  ]
  where
    done = assignments st

-- | A state is consistent when none of its assignment pairs conflict.
consistent :: CSP -> State -> Bool
consistent csp st = null (inconsistencies csp st)

-- | Keep only the consistent candidates.
test :: CSP -> [State] -> [State]
test csp candidates = filter (consistent csp) candidates

-- | Naive generate-and-test solver.
solver :: CSP -> [State]
solver csp = test csp (generate csp)
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 3
--------------------------------------------------------------------------- 
-- | The n-queens problem: variables are columns, values are rows. Two queens
-- are compatible when they share neither a row nor a diagonal.
queens :: Int -> CSP
queens n = CSP {vars = n, vals = n, rel = safe}
  where
    safe (c1 := r1) (c2 := r2) =
      r1 /= r2 && abs (c1 - c2) /= abs (r1 - r2)

-- | Graph colouring: @nodes@ variables and @colors@ values. Two nodes may share
-- a colour only if @adj@ says they are not adjacent.
graphcoloring :: Int -> ((Var,Var) -> Bool) -> Int -> CSP
graphcoloring nodes adj colors = CSP {vars = nodes, vals = colors, rel = ok}
  where
    ok (u := cu) (w := cw)
      | cu == cw  = not (adj (u, w))
      | otherwise = True
--------------------------------------------------------------------------- 


--------------------------------------------------------------------------- 
-- Figure 4
--------------------------------------------------------------------------- 
-- | Rose tree: a label and an ordered list of subtrees.
data Tree a = Node a [Tree a]

-- | Build a node from a label and its children.
mkTree :: a -> [Tree a] -> Tree a
mkTree = Node

-- | The root label.
label :: Tree a -> a
label (Node x _) = x

-- | Lazily unfold a (possibly infinite) tree from a seed and a child generator.
initTree :: (a -> [a]) -> a -> Tree a
initTree children = go
  where
    go x = Node x [go c | c <- children x]

-- | Apply a function to every label.
mapTree  :: (a -> b) -> Tree a -> Tree b
mapTree g = go
  where
    go (Node x kids) = Node (g x) [go k | k <- kids]

-- | Bottom-up fold: each label is combined with its children's results.
foldTree :: (a -> [b] -> b) -> Tree a -> b
foldTree g = go
  where
    go (Node x kids) = g x [go k | k <- kids]

-- | Combine two trees node by node; surplus children on either side are dropped.
zipTreesWith :: (a -> b -> c) -> Tree a -> Tree b -> Tree c
zipTreesWith g (Node x xs) (Node y ys) =
  Node (g x y) (zipWith (zipTreesWith g) xs ys)

-- | Remove every subtree below the root whose root label satisfies @p@.
-- The root itself is never removed.
prune :: (a -> Bool) -> Tree a -> Tree a
prune p = foldTree keep
  where
    keep x kids = Node x [k | k <- kids, not (p (label k))]

-- | Leaf labels, left to right.
leaves :: Tree a -> [a]
leaves (Node x [])   = [x]
leaves (Node _ kids) = concatMap leaves kids

-- | Top-down accumulation: a node's new label is @f@ applied to its parent's
-- new label (the seed @b@ at the root) and its own old label.
inhTree :: (b -> a -> b) -> b -> Tree a -> Tree b
inhTree f inherited (Node x kids) = Node here (map (inhTree f here) kids)
  where
    here = f inherited x

-- | Push values downward: @f@ of a node's label yields one value per child,
-- which becomes that child's new label; the root gets @b@.
distrTree :: (a -> [b]) -> b -> Tree a -> Tree b
distrTree f b (Node x kids) = Node b (zipWith (distrTree f) (f x) kids)
--------------------------------------------------------------------------- 


--------------------------------------------------------------------------- 
-- Figure 5
--------------------------------------------------------------------------- 
-- | The complete lazy search tree: the children of a state are its extensions.
mkSearchTree :: CSP -> Tree State
mkSearchTree csp = initTree (extensions csp) root
  where
    root = emptyState csp

-- | Compare the most recent assignment against the earlier ones. Return the
-- variable of the oldest assignment it is incompatible with, or 'Nothing'.
earliestInconsistency :: CSP -> State -> Maybe Var
earliestInconsistency CSP{rel = compatible} st =
  case assignments st of
    []            -> Nothing
    newest : rest ->
      case [var old | old <- reverse rest, not (compatible newest old)] of
        []    -> Nothing
        v : _ -> Just v

-- | Pair each state in the tree with its earliest inconsistency.
labelInconsistencies :: CSP -> Tree State -> Tree (State,Maybe Var)
labelInconsistencies csp = mapTree (\s -> (s, earliestInconsistency csp s))

-- | Simple backtracking. Cut off every subtree rooted at an inconsistent
-- state, then return the complete states left at the leaves.
btsolver0 :: CSP -> [State]
btsolver0 csp = filter complete (map fst (leaves pruned))
  where
    labelled = labelInconsistencies csp (mkSearchTree csp)
    pruned   = prune (inconsistent . snd) labelled
    inconsistent Nothing  = False
    inconsistent (Just _) = True
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 7
--------------------------------------------------------------------------- 
-- | The variables jointly blamed for a dead end.
type ConflictSet = OrderedSet Var

-- | A node is a conflict exactly when its conflict set is non-empty.
isConflict :: ConflictSet -> Bool
isConflict cs = not (isEmptySet cs)

-- | Cut off conflicting subtrees and return the complete states at the leaves.
solutions :: Tree (State, ConflictSet) -> [State]
solutions t =
  [s | s <- map fst (leaves (prune (isConflict . snd) t)), complete s]

-- | A labeler annotates every state of a search tree with a conflict set.
type Labeler = CSP -> Tree State -> Tree (State, ConflictSet)

-- | Build the search tree, label it, and read off the solutions.
search :: Labeler -> CSP -> [State]
search labeler csp = solutions (labeler csp (mkSearchTree csp))

-- | Backtracking labeler. If a state's newest assignment conflicts with an
-- earlier variable @a@, the state is labelled {newest variable, a};
-- otherwise it gets the empty set.
bt :: Labeler
bt csp = mapTree conflictsOf
  where
    conflictsOf s = (s, maybe emptySet (pairWith s) (earliestInconsistency csp s))
    pairWith s a  = listToSet [var (lastAssignment s), a]

-- | Solver using plain backtracking.
btsolver :: CSP -> [State]
btsolver = search bt
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 8
--------------------------------------------------------------------------- 
-- Ordered-set interface used for conflict sets; the definitions are in the
-- implementation section below.
emptySet      :: Ord a => OrderedSet a                                  -- no elements
isEmptySet    :: Ord a => OrderedSet a -> Bool                          -- emptiness test
memberSet     :: Ord a => OrderedSet a -> a -> Bool                     -- membership test
unionSet      :: Ord a => OrderedSet a -> OrderedSet a -> OrderedSet a  -- union
intersectSet  :: Ord a => OrderedSet a -> OrderedSet a -> OrderedSet a  -- intersection
removeFromSet :: Ord a => a -> OrderedSet a -> OrderedSet a             -- delete one element
listToSet     :: Ord a => [a] -> OrderedSet a                           -- build from a list
evalSet       :: Ord a => OrderedSet a -> OrderedSet a                  -- force representation
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Ordered set implementations (NOT IN PAPER)
--------------------------------------------------------------------------- 
-- | A finite set stored as a duplicate-free list in strictly /decreasing/
-- order, so merges and membership tests can stop early.
--
-- The @Ord a@ constraint lives on the operations, not on the type. Datatype
-- contexts were removed in Haskell 2010 (GHC accepts them only under the
-- deprecated DatatypeContexts extension), and every operation already
-- requires @Ord a@ in its own signature.
newtype OrderedSet a = OrderedSet [a] deriving Show

-- The empty set.
emptySet = OrderedSet []

-- Emptiness test.
isEmptySet (OrderedSet xs) = null xs

-- | One-element set.
singletonSet :: Ord a => a -> OrderedSet a
singletonSet x = OrderedSet [x]

-- Membership test: walk down the decreasing list and give up as soon as the
-- elements fall below the one sought.
memberSet (OrderedSet xs) x = go xs
  where
    go []      = False
    go (h : t) = case compare h x of
      EQ -> True
      LT -> False
      GT -> go t

-- Union: merge two decreasing lists, keeping one copy of shared elements.
unionSet (OrderedSet s1) (OrderedSet s2) = OrderedSet (merge s1 s2)
  where
    merge [] r = r
    merge l [] = l
    merge l@(x : xs) r@(y : ys) = case compare x y of
      EQ -> x : merge xs ys
      LT -> y : merge l ys
      GT -> x : merge xs r

-- Intersection: walk both decreasing lists, keeping elements found in both.
intersectSet (OrderedSet s1) (OrderedSet s2) = OrderedSet (common s1 s2)
  where
    common [] _ = []
    common _ [] = []
    common l@(x : xs) r@(y : ys) = case compare x y of
      EQ -> x : common xs ys
      LT -> common l ys
      GT -> common xs r

-- | Elements of the first set that are not in the second.
diffSet :: Ord a => OrderedSet a -> OrderedSet a -> OrderedSet a
diffSet (OrderedSet s1) (OrderedSet s2) = OrderedSet (minus s1 s2)
  where
    minus [] _ = []
    minus l [] = l
    minus l@(x : xs) r@(y : ys) = case compare x y of
      EQ -> minus xs ys
      LT -> minus l ys
      GT -> x : minus xs r

-- | Insert one element.
addToSet :: Ord a => a -> OrderedSet a -> OrderedSet a
addToSet a s = unionSet s (singletonSet a)

-- | Insert every element of a list.
extendSet :: Ord a => OrderedSet a -> [a] -> OrderedSet a
extendSet = foldr addToSet

-- Build a set from a list in O(n log n): sort into decreasing order, then
-- keep one representative of each run of equal elements. Inserting the
-- elements one at a time through 'unionSet' would cost O(n^2).
listToSet xs = OrderedSet (map head (group (sortBy (flip compare) xs)))

-- Delete one element (a no-op if it is absent).
removeFromSet a s = diffSet s (singletonSet a)

-- | Delete every element of a list.
retractSet :: Ord a => OrderedSet a -> [a] -> OrderedSet a
retractSet = foldr removeFromSet

-- | Number of elements.
cardSet :: Ord a => OrderedSet a -> Int
cardSet (OrderedSet s) = length s

-- Force the whole spine of the underlying list (via its length) whenever the
-- result is evaluated, so the set is materialised rather than left as a
-- chain of pending merges.
evalSet (OrderedSet s) = OrderedSet (length s `seq` s)
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 10
--------------------------------------------------------------------------- 
-- | Pseudo-randomly reorder the children of every node. Each node scrambles
-- the seed it receives and uses the result both to shuffle its children and
-- to hand each child its own seed, taken from the 'randoms' stream.
hrandom :: Int -> Tree a -> Tree a
hrandom seed t = foldTree shuffle t seed
  where
    shuffle x subtrees s =
      let s'     = random s
          seeded = zipWith ($) subtrees (randoms s')
      in mkTree x (randomizeList s' seeded)

-- | Backtracking over a randomly reordered search tree.
btr :: Int -> Labeler
btr seed csp = bt csp . hrandom seed
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Support for random numbers (NOT IN PAPER)
-- (This should be revised to use standard Haskell98 library random functions.)
--------------------------------------------------------------------------- 
-- | One step of the Park-Miller "minimal standard" generator (multiplier
-- 16807, modulus 2147483647). It uses Schrage's decomposition (q = 127773,
-- r = 2836) to keep intermediate products in range. Note that @hi@ uses
-- 'div' while @lo@ uses 'rem'.
random2 :: Int -> Int
random2 n
  | t > 0     = t
  | otherwise = t + 2147483647
  where
    t = 16807 * (n `rem` 127773) - 2836 * (n `div` 127773)

-- | Infinite stream of generator states, starting with the seed itself.
randoms :: Int -> [Int]
randoms seed = iterate random2 seed

-- | Cheap seed scrambler: the affine map @a * n + a@. No modulus is applied
-- (a @mod m@ was deliberately left out), so large inputs overflow 'Int'.
random :: Int -> Int
random n = mult * n + mult
  where
    mult = 994108973

-- | Shuffle a list by sorting it on the pseudo-random keys @randoms i@.
randomizeList :: Int -> [a] -> [a]
randomizeList i xs = map snd (sortBy byKey (zip (randoms i) xs))
  where
    byKey (k1, _) (k2, _) = compare k1 k2
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 11
--------------------------------------------------------------------------- 
-- | First backjumping variant, layered on the backtracking labeler.
bj0bt :: Labeler
bj0bt csp = bj0 csp . bt csp

-- | Bottom-up relabelling: a node that already has a conflict keeps it. Any
-- other node gets the set that 'combine' derives from its children's labels.
bj0 :: CSP -> Tree (State, ConflictSet) -> Tree (State, ConflictSet)
bj0 csp = foldTree relabel
  where
    relabel (s, cs) kids
      | isConflict cs = mkTree (s, cs) kids
      | otherwise     = mkTree (s, combine (map label kids) []) kids

-- | Union of a list of conflict sets.
unionCS :: [ConflictSet] -> ConflictSet
unionCS css = foldr unionSet emptySet css

-- | Derive a parent's conflict set from its children's labels, scanning the
-- children left to right.
--
-- * A child whose conflict set does not mention the variable that child just
--   assigned has a conflict that does not involve the choice made at this
--   level. That set is returned immediately and the remaining children are
--   skipped. An empty set never contains @lastvar@, so it is also returned
--   here; no separate empty-set case is needed.
-- * Otherwise the child's set, minus its own variable, is pushed onto @acc@.
--   When the children run out, the accumulated sets are unioned.
combine ::  [(State, ConflictSet)] -> [ConflictSet] -> ConflictSet
combine [] acc = unionCS acc
combine ((s, cs) : rest) acc
  | not (memberSet cs lastvar) = cs
  | otherwise                  = combine rest (removeFromSet lastvar cs : acc)
  where
    lastvar = var (lastAssignment s)

-- | Backjumping layered on top of the plain backtracking labeler 'bt'.
bjbt :: Labeler 
bjbt csp = bj csp . bt csp

-- | Backjumping relabeler, applied bottom-up via 'foldTree'. A node that
-- already has a conflict keeps it; otherwise its set is computed from its
-- children's labels by 'combine'.
--
-- It differs from 'bj0' in two places, marked "leak" below:
--   * when the combined set is a conflict, the node's children are replaced
--     by @[]@, so the node no longer references its explored subtree;
--   * the combined set is wrapped in 'evalSet', which forces the whole list
--     behind it once the guard inspects it, instead of leaving a chain of
--     unevaluated set operations over the children's labels.
-- The @csp@ argument is not used by the body.
bj :: CSP -> Tree (State, ConflictSet) -> Tree (State, ConflictSet)
bj csp = foldTree f 
  where f (s, cs) ts
           | isConflict cs  = mkTree (s, cs)  ts
           | isConflict cs' = mkTree (s, cs') []  -- plug first leak
           | otherwise      = mkTree (s, cs') ts
               where cs' = evalSet(combine (map label ts) []) -- plug second leak
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 13
--------------------------------------------------------------------------- 
-- | Backmarking labeler: cache conflict information down the tree, then read
-- each node's conflict set out of its parent's cache.
bm :: Labeler
bm csp t = extractConflicts (storeConflicts csp t)

-- | Annotate every state with a conflict cache derived from its parent's.
-- The root's inherited pair is a placeholder, because 'augmentConflicts'
-- ignores the parent table for the empty state.
storeConflicts :: CSP -> Tree State -> Tree (State,Cache ConflictSet)
storeConflicts csp = inhTree step (undefined, undefined)
  where
    step (_, parentCache) s = (s, augmentConflicts csp parentCache s)

-- | Derive a state's conflict cache from its parent's.
--
-- For the empty state, every entry starts out as the empty set. Otherwise:
--   * the newly assigned variable's row is dropped;
--   * every remaining (variable := value) entry that was not already a
--     conflict becomes {new variable, that variable} if it is incompatible
--     with the new assignment, and stays empty if it is compatible.
augmentConflicts :: CSP -> Cache ConflictSet -> State -> Cache ConflictSet
augmentConflicts csp@CSP{rel = compatible} parentTbl s
  | isEmptyState s = initCache csp emptySet
  | otherwise      = mapCache update (thinCache parentTbl (var newest))
  where
    newest = lastAssignment s
    update :: Assignment -> ConflictSet -> ConflictSet
    update a cs
      | isConflict cs       = cs
      | compatible newest a = emptySet
      | otherwise           = listToSet [var newest, var a]


-- | Give each state the conflict set its parent cached for it. The parent's
-- cache row for its next variable is distributed positionally over its
-- children, and the root receives the empty set.
extractConflicts :: Tree (State,Cache ConflictSet) -> Tree (State,ConflictSet)
extractConflicts t = zipTreesWith attach t perChild
  where
    perChild        = distrTree rowFor emptySet t
    rowFor (s, tbl) = lookupCache tbl (nextVar s)
    attach (s, _) cs = (s, cs)
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 14
--------------------------------------------------------------------------- 
-- | A table with one row per variable; the row holds one entry per value,
-- where entry i belongs to value i.
data Cache a = Cache [(Var,[a])]

-- | A table covering variables @1..vars@ with every entry set to @i@.
initCache :: CSP -> a -> Cache a
initCache CSP{vars = nvars, vals = nvals} i =
  Cache [(v, row) | v <- [1 .. nvars]]
  where
    row = replicate nvals i

-- | Drop the row belonging to @var0@.
thinCache :: Cache a -> Var -> Cache a
thinCache (Cache cache) var0 = Cache [(v, row) | (v, row) <- cache, v /= var0]

-- | Rewrite every entry. @f@ receives the assignment (variable := value) that
-- the entry stands for, together with the old entry.
mapCache :: (Assignment -> a -> a) -> Cache a -> Cache a
mapCache f (Cache cache) = Cache [(v, newRow v row) | (v, row) <- cache]
  where
    newRow v row = [f (v := val) x | (val, x) <- zip [1 ..] row]

-- | The row for a variable. Fails with a descriptive error, rather than an
-- anonymous pattern-match failure, if the table has no row for it (e.g.
-- because 'thinCache' already removed it).
lookupCache :: Cache a -> Var -> [a]
lookupCache (Cache cache) v =
  case lookup v cache of
    Just row -> row
    Nothing  -> error ("lookupCache: no row for variable " ++ show v)

-- | All (variable, row) pairs in the table.
getCache :: Cache a -> [(Var,[a])]
getCache (Cache cache) = cache
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 15
--------------------------------------------------------------------------- 
-- | Labeler combining the cached conflicts with a wiped-out-domain check.
mfc :: Labeler
mfc csp t = mfc' csp (storeConflicts csp t)

-- | Keep the conflict set extracted from the parent's cache when it is a
-- conflict; otherwise use the set reported by 'wipedDomain'.
mfc' :: CSP -> Tree (State,Cache ConflictSet) -> Tree (State,ConflictSet)
mfc' csp t = zipTreesWith pick (extractConflicts t) (mapTree (wipedDomain csp) t)
  where
    pick (s, cs) wiped
      | isConflict cs = (s, cs)
      | otherwise     = (s, wiped)

-- | Look for a variable whose cache row is entirely conflicts, i.e. whose
-- domain has been wiped out. For the first such row, return the union of
-- its conflict sets restricted to the variables assigned in this state.
-- Return the empty set if there is no such row.
wipedDomain :: CSP -> (State, Cache ConflictSet) -> ConflictSet
wipedDomain CSP{} (s, tbl) =
  case [css | (_, css) <- getCache tbl, all isConflict css] of
    []      -> emptySet
    css : _ -> intersectSet (unionCS css) (listToSet (map var (assignments s)))
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 16
--------------------------------------------------------------------------- 
-- | Parameters for dynamic-variable-ordering search:
--   (relabeler turning a pre-labelled tree into a conflict-labelled one,
--    selector picking the variable a node's children will assign,
--    prelabeler computing a node's annotation from its parent's).
type DVOParams a =
  ( CSP -> Tree (State, a) -> Tree (State, ConflictSet)
  , CSP -> a -> State -> Var
  , CSP -> a -> State -> a
  )

-- | Search with dynamic variable ordering.
searchDVO :: DVOParams a -> CSP -> [State]
searchDVO (relabeler, selector, prelabeler) csp =
  solutions (relabeler csp tree)
  where
    tree = mkSearchTreeDVO (prelabeler csp) (selector csp) csp

-- | Build the search tree, annotating each node with the prelabeler. The
-- annotation is computed from the freshly extended state; the selector then
-- picks which variable that node's own children will assign.
mkSearchTreeDVO
  :: (a -> State -> a) -> (a -> State -> Var) -> CSP -> Tree (State, a)
mkSearchTreeDVO prelabeler selector csp = initTree children (rootState, rootAnn)
  where
    rootState = emptyState csp
    rootAnn   = prelabeler undefined rootState
    children (s, ann) = map (withNext ann) (extensions csp s)
    withNext ann s' =
      let ann' = prelabeler ann s'
      in (newNextVar s' (selector ann' s'), ann')
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 17
--------------------------------------------------------------------------- 
-- | Fail-first variable choice: the variable whose cache row has the fewest
-- non-conflicting values (the earliest such variable on ties).
failFirst0 :: Cache ConflictSet -> State -> Var
failFirst0 tbl _ = fst (foldr1 pickSmaller sized)
  where
    sized = [(v, length (filter (not . isConflict) css)) | (v, css) <- getCache tbl]
    pickSmaller l r = if snd l <= snd r then l else r

-- | Fail-first DVO parameters using the cached conflicts.
ff0 :: DVOParams (Cache ConflictSet)
ff0 = (const extractConflicts, const failFirst0, augmentConflicts)

-- | Solver driven by 'ff0'.
ff0solver :: CSP -> [State]
ff0solver = searchDVO ff0

-- | Fail-first choice that counts remaining values with lazy naturals, so each
-- comparison looks only as far as the smaller count (earliest on ties).
failFirst :: Cache ConflictSet -> State -> Var
failFirst tbl _ = fst (foldr1 pickSmaller sized)
  where
    sized = [(v, nlength (filter (not . isConflict) css)) | (v, css) <- getCache tbl]
    pickSmaller l r
      | snd l `nleq` snd r = l
      | otherwise          = r

-- | Fail-first DVO parameters using lazy counting.
ff :: DVOParams (Cache ConflictSet)
ff = (const extractConflicts, const failFirst, augmentConflicts)

-- | Fail-first choice that counts every domain down in lock step and picks the
-- first variable to reach zero. With lazy naturals, no count is evaluated
-- past the smallest one.
failFirst1 :: Cache ConflictSet -> State -> Var
failFirst1 tbl _ = fst (race sized)
  where
    sized = [(v, nlength (filter (not . isConflict) css)) | (v, css) <- getCache tbl]
    race doms =
      case [d | d@(_, n) <- doms, isZ n] of
        d : _ -> d
        []    -> race [(v, npred n) | (v, n) <- doms]

-- | Fail-first DVO parameters using the lock-step race.
ff1 :: DVOParams (Cache ConflictSet)
ff1 = (const extractConflicts, const failFirst1, augmentConflicts)
--------------------------------------------------------------------------- 

--------------------------------------------------------------------------- 
-- Figure 18
--------------------------------------------------------------------------- 
-- | Lazy Peano naturals, so lengths can be compared without being fully computed.
data Nat = Z | S Nat

-- | @m `nleq` n@ is @m <= n@, examining only as much of either as needed.
nleq :: Nat -> Nat -> Bool
nleq Z     _     = True
nleq (S _) Z     = False
nleq (S m) (S n) = nleq m n

-- | Zero test.
isZ :: Nat -> Bool
isZ n = case n of
  Z   -> True
  S _ -> False

-- | Predecessor; raises the error "Pred" at zero.
npred :: Nat -> Nat
npred Z     = error "Pred"
npred (S n) = n

-- | Length as a lazy natural. Works on infinite lists as long as only a
-- bounded prefix of the result is inspected.
nlength :: [a] -> Nat
nlength = foldr (\_ n -> S n) Z
--------------------------------------------------------------------------- 

-- | Backjumping on top of the 'mfc' labeler.
bjmfc :: Labeler
bjmfc csp t = bj csp (mfc csp t)

-- | Backjumping on top of backmarking.
bjbm :: Labeler
bjbm csp t = bj csp (bm csp t)

-- | DVO parameters: 'mfc'' relabelling with 'failFirst' variable choice.
mfcff :: DVOParams (Cache ConflictSet)
mfcff = (mfc', const failFirst, augmentConflicts)

-- | DVO parameters: 'mfc'' relabelling with 'failFirst1' variable choice.
mfcff1 :: DVOParams (Cache ConflictSet)
mfcff1 = (mfc', const failFirst1, augmentConflicts)

-- | DVO parameters: backjumping over the extracted conflicts, with 'failFirst1'.
bjff1 :: DVOParams (Cache ConflictSet)
bjff1 = (\csp t -> bj csp (extractConflicts t), const failFirst1, augmentConflicts)

