-- |
-- Module      : MonusWeightedSearch.Examples.Dijkstra
-- Copyright   : (c) Donnacha Oisín Kidney 2021
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : non-portable
--
-- An implementation of Dijkstra's algorithm, using the 'HeapT' monad.
--
-- This is taken from section 6.1.3 of the paper
--
-- * Donnacha Oisín Kidney and Nicolas Wu. 2021. /Algebras for weighted search/.
--   Proc. ACM Program. Lang. 5, ICFP, Article 72 (August 2021), 30 pages.
--   DOI:<https://doi.org/10.1145/3473577>
--
-- This is a pretty simple implementation of the algorithm, defined monadically,
-- but it retains the time complexity of a standard purely functional
-- implementation.
--
-- We use the state monad here to avoid searching from the same node more than
-- once (which would lead to an infinite loop). Different algorithms use
-- different permutations of the monad transformers: for Dijkstra's algorithm,
-- we use @'HeapT' w ('State' ('Set' a)) a@, i.e. the 'HeapT' is outside of the
-- 'State'. This means that each branch of the search proceeds with a different
-- state; if we switch the order (to @'StateT' s ('Heap' w) a@, for example), we
-- get "global" state, which has the semantics of a /parser/. For an example
-- of that, see the module "MonusWeightedSearch.Examples.Parsing", where the
-- heap is used to implement a probabilistic parser.

module MonusWeightedSearch.Examples.Dijkstra where

import Prelude hiding (head)
import Control.Monad.State.Strict
import Control.Applicative
import Control.Monad.Writer
import Control.Monad
import Data.Foldable

import Data.Monus.Dist
import Data.Set (Set)
import qualified Data.Set as Set

import Data.List.NonEmpty (NonEmpty(..))

import Control.Monad.Heap

-- $setup
-- >>> import Prelude hiding (head)
-- >>> import Data.List.NonEmpty (head)

-- | The example graph from
-- <https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm the Wikipedia article on Dijkstra's algorithm>.
--
-- <<https://upload.wikimedia.org/wikipedia/commons/5/57/Dijkstra_Animation.gif>>
--
-- Each vertex is mapped to its outgoing edges, given as @(target, weight)@
-- pairs. Edges are directed as encoded here (e.g. there is an edge @1 -> 2@
-- but none @2 -> 1@). Vertex 5 is a sink, and any vertex outside @1..6@ has
-- no outgoing edges.
graph :: Graph Int
graph 1 = [(2, 7), (3, 9), (6, 14)]
graph 2 = [(3, 10), (4, 15)]
graph 3 = [(4, 11), (6, 2)]
graph 4 = [(5, 6)]
graph 5 = []
graph 6 = [(5, 9)]
graph _ = []

-- | @'unique' x@ checks that @x@ has not yet been seen in this branch of the
-- computation.
--
-- The set of already-visited values lives in the underlying 'State'. If @x@
-- is already a member, the current branch fails (via 'guard'). Otherwise @x@
-- is inserted into the set and returned unchanged.
unique :: Ord a => a -> HeapT w (State (Set a)) a
unique x = do
  seen <- get
  -- Prune this branch if we have already searched from x.
  guard (Set.notMember x seen)
  -- Record x so later visits to it are pruned.
  modify (Set.insert x)
  pure x
{-# INLINE unique #-}

-- | This is the Kleene star on the semiring of 'MonadPlus'. It is analogous to
-- the 'many' function on 'Alternative's.
--
-- @'star' f x@ offers, as alternatives combined with '<|>', the value @x@
-- itself and every value reachable from @x@ by one or more applications of
-- @f@:
--
-- > star f x = pure x <|> (f x >>= star f)
star :: MonadPlus m => (a -> m a) -> a -> m a
star f x = pure x <|> (f x >>= star f)
{-# INLINE star #-}

-- | This is a version of 'star' which keeps track of the inputs it was given.
--
-- Each result is the whole path taken to reach a value, most recent value
-- first. The head of the 'NonEmpty' is the value reached, and its last
-- element is the starting value.
pathed :: MonadPlus m => (a -> m a) -> a -> m (NonEmpty a)
pathed f = star extend . (:| [])
  where
    -- Step from the newest value and push the result onto the front of the
    -- path. The irrefutable pattern avoids forcing the path before it is needed.
    extend ~(x :| xs) = fmap (:| (x : xs)) (f x)
{-# INLINE pathed #-}

-- | Dijkstra's algorithm. This function returns the length of the shortest path
-- from a given vertex to every vertex reachable from it in the graph.
--
-- >>> dijkstra graph 1
-- [(1,0),(2,7),(3,9),(6,11),(5,20),(4,20)]
--
-- A version which actually produces the paths is 'shortestPaths'.
dijkstra :: Ord a => Graph a -> a -> [(a, Dist)]
dijkstra g x =
  evalState (searchT (star step =<< unique x)) Set.empty
  where
    -- Try every outgoing edge of v as a separate alternative. Each edge
    -- records its weight with 'tell' and is kept only if its target is new
    -- (see 'unique').
    step v = asum [ tell w >> unique y | (y, w) <- g v ]
{-# INLINE dijkstra #-}

-- | Dijkstra's algorithm, which produces a path.
--
-- The only difference between this function and 'dijkstra' is that this
-- uses 'pathed' rather than 'star'.
--
-- The following finds the shortest path from vertex 1 to 5:
--
-- >>> filter ((5==) . head . fst) (shortestPaths graph 1)
-- [(5 :| [6,3,1],20)]
--
-- And it is indeed @[1,3,6,5]@. Paths are returned in reverse, with the
-- most recent vertex first.
shortestPaths :: Ord a => Graph a -> a -> [(NonEmpty a, Dist)]
shortestPaths g x =
  evalState (searchT (pathed step =<< unique x)) Set.empty
  where
    -- Try every outgoing edge of v as a separate alternative. Each edge
    -- records its weight with 'tell' and is kept only if its target is new
    -- (see 'unique').
    step v = asum [ tell w >> unique y | (y, w) <- g v ]
{-# INLINE shortestPaths #-}