{-# LANGUAGE DefaultSignatures     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}

module Control.Monad.Class.MonadAsync
  ( MonadAsync (..)
  , MonadAsyncSTM (..)
  , AsyncCancelled(..)
  , ExceptionInLinkedThread(..)
  , link
  , linkTo
  , linkOnly
  , linkToOnly

  , mapConcurrently, forConcurrently
  , mapConcurrently_, forConcurrently_
  , replicateConcurrently, replicateConcurrently_
  , Concurrently (..)
  ) where

import           Prelude hiding (read)

import           Control.Applicative (Alternative (..), liftA2)
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadTimer
import           Control.Monad.Class.MonadThrow

import           Control.Concurrent.Async (AsyncCancelled (..))
import qualified Control.Concurrent.Async as Async
import qualified Control.Exception as E
import           Control.Monad.Reader
import qualified Control.Monad.STM as STM

import           Data.Foldable (fold)
import           Data.Kind (Type)
import           Data.Proxy

-- | Operations on asynchronous actions that run inside the STM monad @stm@,
-- so that they can be composed with other STM transactions.
--
-- Only 'waitCatchSTM' and 'pollSTM' must be implemented.  Every other method
-- has a default written in terms of them together with 'orElse' and 'retry'.
-- The defaults of 'waitSTM', 'waitAnySTM', 'waitEitherSTM', 'waitEitherSTM_'
-- and 'waitBothSTM' additionally require @'MonadThrow' stm@.
--
class (Functor async, MonadSTMTx stm) => MonadAsyncSTM async stm where
  {-# MINIMAL waitCatchSTM, pollSTM #-}

  -- | Wait for the action to finish and return its result.  If it failed,
  -- its exception is rethrown (the default does this with 'throwSTM').
  waitSTM      :: async a -> stm a
  -- | Inspect the outcome of the action: 'Nothing' while it is still
  -- running, otherwise its exception or its result.
  pollSTM      :: async a -> stm (Maybe (Either SomeException a))
  -- | Wait for the action to finish, reporting a failure as 'Left' instead
  -- of rethrowing it.
  waitCatchSTM :: async a -> stm (Either SomeException a)

  default waitSTM :: MonadThrow stm => async a -> stm a
  waitSTM action = waitCatchSTM action >>= either throwSTM return

  -- | Return the first action in the list that has finished, together with
  -- its result.  Earlier list elements take priority.  On an empty list the
  -- default is just 'retry'.
  waitAnySTM            :: [async a] -> stm (async a, a)
  -- | Like 'waitAnySTM', but the outcome is reported as
  -- @Either SomeException a@ instead of being rethrown.
  waitAnyCatchSTM       :: [async a] -> stm (async a, Either SomeException a)
  -- | Wait for either action.  @left@ is preferred when both have finished.
  waitEitherSTM         :: async a -> async b -> stm (Either a b)
  -- | Like 'waitEitherSTM', but discards the result.
  waitEitherSTM_        :: async a -> async b -> stm ()
  -- | Like 'waitEitherSTM', but each side's outcome is reported as an
  -- @Either SomeException@ instead of being rethrown.
  waitEitherCatchSTM    :: async a -> async b
                        -> stm (Either (Either SomeException a)
                                         (Either SomeException b))
  -- | Wait for both actions.  While @left@ is still running, the default
  -- already looks at @right@, so a failure of @right@ surfaces without
  -- waiting for @left@.
  waitBothSTM           :: async a -> async b -> stm (a, b)

  default waitAnySTM     :: MonadThrow stm => [async a] -> stm (async a, a)
  default waitEitherSTM  :: MonadThrow stm => async a -> async b -> stm (Either a b)
  default waitEitherSTM_ :: MonadThrow stm => async a -> async b -> stm ()
  default waitBothSTM    :: MonadThrow stm => async a -> async b -> stm (a, b)

  -- Chain one waiting transaction per handle with 'orElse'; 'retry' is the
  -- fallback once every handle has been tried.
  waitAnySTM asyncs =
    foldr orElse retry [ (,) a <$> waitSTM a | a <- asyncs ]

  waitAnyCatchSTM asyncs =
    foldr orElse retry [ (,) a <$> waitCatchSTM a | a <- asyncs ]

  waitEitherSTM left right =
        (Left  <$> waitSTM left)
      `orElse`
        (Right <$> waitSTM right)

  waitEitherSTM_ left right =
        void (waitSTM left)
      `orElse`
        void (waitSTM right)

  waitEitherCatchSTM left right =
        (Left  <$> waitCatchSTM left)
      `orElse`
        (Right <$> waitCatchSTM right)

  waitBothSTM left right = do
      -- If @left@ is not done yet, check @right@: 'waitSTM' on a failed
      -- @right@ raises its exception now; otherwise we 'retry'.
      a <- waitSTM left
             `orElse`
           (waitSTM right >> retry)
      b <- waitSTM right
      return (a, b)

-- | Monads that can run actions asynchronously, following the API of
-- "Control.Concurrent.Async".
--
-- Only 'async', 'asyncThreadId', 'cancel', 'cancelWith' and
-- 'asyncWithUnmask' must be implemented.  The remaining methods have
-- defaults, some of which require 'MonadMask' or 'MonadThrow'.
--
class ( MonadSTM m
      , MonadThread m
      , MonadAsyncSTM (Async m) (STM m)
      ) => MonadAsync m where

  {-# MINIMAL async, asyncThreadId, cancel, cancelWith, asyncWithUnmask #-}

  -- | An asynchronous action
  type Async m :: Type -> Type

  -- | Start running an action asynchronously.
  async                 :: m a -> m (Async m a)
  -- | The thread that runs the asynchronous action.  The 'Proxy' fixes @m@,
  -- which cannot be inferred from @Async m a@ because 'Async' is a type
  -- family.
  asyncThreadId         :: Proxy m -> Async m a -> ThreadId m
  -- | Run an action asynchronously for the duration of a scope.  The
  -- default cancels it with 'uninterruptibleCancel' when the scope exits,
  -- whether normally or by an exception.
  withAsync             :: m a -> (Async m a -> m b) -> m b

  -- | Wait for the result, rethrowing the action's exception if it failed.
  wait                  :: Async m a -> m a
  -- | Look at the outcome without waiting: 'Nothing' while still running.
  poll                  :: Async m a -> m (Maybe (Either SomeException a))
  -- | Wait for the action, reporting a failure as 'Left'.
  waitCatch             :: Async m a -> m (Either SomeException a)
  -- | Cancel the asynchronous action.
  cancel                :: Async m a -> m ()
  -- | Cancel the asynchronous action using the given exception.
  cancelWith            :: Exception e => Async m a -> e -> m ()
  -- | 'cancel' that cannot be interrupted (default: under
  -- 'uninterruptibleMask_').
  uninterruptibleCancel :: Async m a -> m ()

  waitAny               :: [Async m a] -> m (Async m a, a)
  waitAnyCatch          :: [Async m a] -> m (Async m a, Either SomeException a)
  -- | 'waitAny', then cancel every action in the list (default: in a
  -- 'finally', so this also happens on exceptions).
  waitAnyCancel         :: [Async m a] -> m (Async m a, a)
  waitAnyCatchCancel    :: [Async m a] -> m (Async m a, Either SomeException a)
  waitEither            :: Async m a -> Async m b -> m (Either a b)

  -- | Note, IO-based implementations should override the default
  -- implementation. See the @async@ package implementation and comments.
  -- <http://hackage.haskell.org/package/async-2.2.1/docs/src/Control.Concurrent.Async.html#waitEitherCatch>
  waitEitherCatch       :: Async m a -> Async m b -> m (Either (Either SomeException a)
                                                               (Either SomeException b))
  waitEitherCancel      :: Async m a -> Async m b -> m (Either a b)
  waitEitherCatchCancel :: Async m a -> Async m b -> m (Either (Either SomeException a)
                                                               (Either SomeException b))
  waitEither_           :: Async m a -> Async m b -> m ()
  waitBoth              :: Async m a -> Async m b -> m (a, b)

  -- | Run both actions concurrently and return the result of the first one
  -- to finish.  By default both are run under 'withAsync'.
  race                  :: m a -> m b -> m (Either a b)
  race_                 :: m a -> m b -> m ()
  -- | Run both actions concurrently and return both results.
  concurrently          :: m a -> m b -> m (a,b)
  concurrently_         :: m a -> m b -> m ()

  -- | Like 'async', but the action is handed a function that restores the
  -- ability to receive asynchronous exceptions.
  asyncWithUnmask       :: ((forall b . m b -> m b) ->  m a) -> m (Async m a)

  -- default implementations
  default withAsync     :: MonadMask m => m a -> (Async m a -> m b) -> m b
  default uninterruptibleCancel
                        :: MonadMask m => Async m a -> m ()
  default waitAnyCancel         :: MonadThrow m => [Async m a] -> m (Async m a, a)
  default waitAnyCatchCancel    :: MonadThrow m => [Async m a]
                                -> m (Async m a, Either SomeException a)
  default waitEitherCancel      :: MonadThrow m => Async m a -> Async m b
                                -> m (Either a b)
  default waitEitherCatchCancel :: MonadThrow m => Async m a -> Async m b
                                -> m (Either (Either SomeException a)
                                             (Either SomeException b))

  -- 'mask' keeps asynchronous exceptions out between forking the child and
  -- installing the 'finally' handler, so the child is always cancelled.
  withAsync action inner = mask $ \restore -> do
    a <- async (restore action)
    restore (inner a) `finally` uninterruptibleCancel a

  wait      = atomically . waitSTM
  poll      = atomically . pollSTM
  waitCatch = atomically . waitCatchSTM

  uninterruptibleCancel = uninterruptibleMask_ . cancel

  waitAny                    = atomically . waitAnySTM
  waitAnyCatch               = atomically . waitAnyCatchSTM
  waitEither      left right = atomically (waitEitherSTM      left right)
  waitEither_     left right = atomically (waitEitherSTM_     left right)
  waitEitherCatch left right = atomically (waitEitherCatchSTM left right)
  waitBoth        left right = atomically (waitBothSTM        left right)

  waitAnyCancel asyncs =
    waitAny asyncs `finally` mapM_ cancel asyncs

  waitAnyCatchCancel asyncs =
    waitAnyCatch asyncs `finally` mapM_ cancel asyncs

  waitEitherCancel left right =
    waitEither left right `finally` (cancel left >> cancel right)

  waitEitherCatchCancel left right =
    waitEitherCatch left right `finally` (cancel left >> cancel right)

  race left right =
    withAsync left  $ \a ->
    withAsync right $ \b ->
      waitEither a b

  race_ left right =
    withAsync left  $ \a ->
    withAsync right $ \b ->
      waitEither_ a b

  concurrently left right =
    withAsync left  $ \a ->
    withAsync right $ \b ->
      waitBoth a b

  concurrently_ left right = void $ concurrently left right

-- | Similar to 'Async.Concurrently', but works for any 'MonadAsync'
-- instance.  Combining values with its 'Applicative' / 'Alternative' /
-- 'Semigroup' instances runs the wrapped actions concurrently.
--
newtype Concurrently m a = Concurrently { runConcurrently :: m a }

-- | Maps over the result of the wrapped action.
instance Functor m => Functor (Concurrently m) where
    fmap f (Concurrently ma) = Concurrently (fmap f ma)

-- | '<*>' runs the function-producing action and the argument-producing
-- action with 'concurrently', then applies one result to the other.
instance ( Applicative m
         , MonadAsync m
         ) => Applicative (Concurrently m) where
    pure = Concurrently . pure

    Concurrently fn <*> Concurrently as =
      Concurrently (uncurry ($) <$> concurrently fn as)

-- | '<|>' races the two actions with 'race' and keeps the result of
-- whichever finishes first.  'empty' is an action that never returns.
--
-- NOTE(review): neither method uses the 'Alternative' @m@ constraint; it
-- looks removable, but is kept here for compatibility.
instance ( Alternative m
         , MonadAsync  m
         , MonadTimer  m
         ) => Alternative (Concurrently m) where
    -- Sleep 86400 seconds (one day) at a time, forever.
    empty = Concurrently $ forever (threadDelay 86400)

    Concurrently as <|> Concurrently bs =
      Concurrently (either id id <$> race as bs)

-- | Runs both actions concurrently and combines their results with '<>'.
instance ( Semigroup  a
         , MonadAsync m
         ) => Semigroup (Concurrently m a) where
    (<>) = liftA2 (<>)

-- | 'mempty' is an action that immediately yields 'mempty'.
instance ( Monoid a
         , MonadAsync m
         ) => Monoid (Concurrently m a) where
    mempty = pure mempty


-- | Apply the function to every element, running all the resulting actions
-- concurrently, and collect the results in the original structure.
mapConcurrently :: (Traversable t, MonadAsync m) => (a -> m b) -> t a -> m (t b)
mapConcurrently f = runConcurrently . traverse (Concurrently . f)

-- | 'mapConcurrently' with its arguments flipped.
forConcurrently :: (Traversable t, MonadAsync m) => t a -> (a -> m b) -> m (t b)
forConcurrently = flip mapConcurrently

-- | Like 'mapConcurrently', but discards the results.
mapConcurrently_ :: (Foldable f, MonadAsync m) => (a -> m b) -> f a -> m ()
mapConcurrently_ f = runConcurrently . foldMap (Concurrently . void . f)

-- | 'mapConcurrently_' with its arguments flipped.
forConcurrently_ :: (Foldable f, MonadAsync m) => f a -> (a -> m b) -> m ()
forConcurrently_ = flip mapConcurrently_

-- | Run @cnt@ copies of the action concurrently and collect their results.
-- A non-positive count yields @[]@ (via 'replicate').
replicateConcurrently :: MonadAsync m => Int -> m a -> m [a]
replicateConcurrently cnt =
  runConcurrently . sequenceA . replicate cnt . Concurrently

-- | Like 'replicateConcurrently', but discards the results.
replicateConcurrently_ :: MonadAsync m => Int -> m a -> m ()
replicateConcurrently_ cnt =
  runConcurrently . fold . replicate cnt . Concurrently . void


--
-- Instance for IO uses the existing async library implementations
--

-- | Every method delegates to the function of the same name in
-- "Control.Concurrent.Async".
instance MonadAsyncSTM Async.Async STM.STM where
  waitSTM            = Async.waitSTM
  pollSTM            = Async.pollSTM
  waitCatchSTM       = Async.waitCatchSTM
  waitAnySTM         = Async.waitAnySTM
  waitAnyCatchSTM    = Async.waitAnyCatchSTM
  waitEitherSTM      = Async.waitEitherSTM
  waitEitherSTM_     = Async.waitEitherSTM_
  waitEitherCatchSTM = Async.waitEitherCatchSTM
  waitBothSTM        = Async.waitBothSTM

-- | The production instance. 'Async' 'IO' is 'Async.Async' from
-- "Control.Concurrent.Async", and each method forwards to that module's
-- function of the same name.
instance MonadAsync IO where

  type Async IO = Async.Async

  -- Spawning threads and identifying them.
  async                 = Async.async
  withAsync             = Async.withAsync
  asyncWithUnmask       = Async.asyncWithUnmask
  -- The 'Proxy' only fixes the monad; the real function does not need it.
  asyncThreadId _       = Async.asyncThreadId

  -- A single handle.
  wait                  = Async.wait
  waitCatch             = Async.waitCatch
  poll                  = Async.poll
  cancel                = Async.cancel
  cancelWith            = Async.cancelWith
  uninterruptibleCancel = Async.uninterruptibleCancel

  -- A list of handles.
  waitAny               = Async.waitAny
  waitAnyCatch          = Async.waitAnyCatch
  waitAnyCancel         = Async.waitAnyCancel
  waitAnyCatchCancel    = Async.waitAnyCatchCancel

  -- Two handles.
  waitBoth              = Async.waitBoth
  waitEither            = Async.waitEither
  waitEither_           = Async.waitEither_
  waitEitherCatch       = Async.waitEitherCatch
  waitEitherCancel      = Async.waitEitherCancel
  waitEitherCatchCancel = Async.waitEitherCatchCancel

  -- Structured combinators over two actions.
  concurrently          = Async.concurrently
  concurrently_         = Async.concurrently_
  race                  = Async.race
  race_                 = Async.race_

--
-- Lift to ReaderT
--

-- | Post-compose a unary function onto a binary one:
-- @(f .: g) x y == f (g x y)@. It is used to 'lift' two-argument methods
-- through 'ReaderT'.
(.:) :: (c -> d) -> (a -> b -> c) -> (a -> b -> d)
f .: g = \x y -> f (g x y)

-- | Lift 'MonadAsync' through 'ReaderT'.
--
-- * Handles are the base monad's handles: 'Async' @(ReaderT r m)@ is
--   'Async' @m@.
-- * Every forking combinator runs its child action(s) against the
--   environment that is current at the call site.
-- * Every waiting or cancelling method is the base method under 'lift'.
instance MonadAsync m => MonadAsync (ReaderT r m) where
  type Async (ReaderT r m) = Async m

  asyncThreadId _ = asyncThreadId (Proxy @m)

  async     (ReaderT ma)   = ReaderT $ \r -> async (ma r)
  withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r)
                                           $ \a -> runReaderT (f a) r
  asyncWithUnmask f        = ReaderT $ \r ->
                               asyncWithUnmask $ \unmask ->
                                 runReaderT (f (liftF unmask)) r
    where
      -- Carry the base monad's @unmask@ over to 'ReaderT' by applying it to
      -- the result of the reader function.
      liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
      liftF g (ReaderT run) = ReaderT (g . run)

  race          (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race          (ma r) (mb r)
  race_         (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> race_         (ma r) (mb r)
  concurrently  (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently  (ma r) (mb r)
  -- Delegate to the base monad's 'concurrently_' instead of falling back on
  -- the class default, in the same way 'race_' is delegated above. This
  -- keeps any specialised base implementation (e.g. the 'IO' instance's
  -- 'Async.concurrently_') in use.
  concurrently_ (ReaderT ma) (ReaderT mb) = ReaderT $ \r -> concurrently_ (ma r) (mb r)

  -- Single-handle operations: the environment is irrelevant, so just 'lift'.
  wait                  = lift .  wait
  poll                  = lift .  poll
  waitCatch             = lift .  waitCatch
  cancel                = lift .  cancel
  uninterruptibleCancel = lift .  uninterruptibleCancel
  cancelWith            = lift .: cancelWith

  -- Operations on a list of handles.
  waitAny               = lift .  waitAny
  waitAnyCatch          = lift .  waitAnyCatch
  waitAnyCancel         = lift .  waitAnyCancel
  waitAnyCatchCancel    = lift .  waitAnyCatchCancel

  -- Operations on two handles.
  waitEither            = lift .: waitEither
  waitEitherCatch       = lift .: waitEitherCatch
  waitEitherCancel      = lift .: waitEitherCancel
  waitEitherCatchCancel = lift .: waitEitherCatchCancel
  waitEither_           = lift .: waitEither_
  waitBoth              = lift .: waitBoth

--
-- Linking
--
-- Adapted from "Control.Concurrent.Async"
--
-- We don't use the implementation of linking from 'Control.Concurrent.Async'
-- directly because:
--
-- 1. We need a generalized form of linking that links an async to an arbitrary
--    thread ('linkTo')
-- 2. If we /did/ use the real implementation, then the mock implementation and
--    the real implementation would not be able to throw the same exception,
--    because the exception type used by the real implementation is
--
-- > data ExceptionInLinkedThread =
-- >   forall a . ExceptionInLinkedThread (Async a) SomeException
--
--    containing a reference to the real 'Async' type.
--

-- | Exception from a child thread, re-raised in the parent thread.
--
-- We record the thread ID of the child thread as a 'String'. This avoids
-- an @m@ parameter in the type, which is important: 'ExceptionInLinkedThread'
-- must be an instance of 'Exception', requiring it to be 'Typeable'; if @m@
-- appeared in the type, we would require @m@ to be 'Typeable', which does not
-- work with the simulator, as it would require a 'Typeable' constraint
-- on the @s@ parameter of 'IOSim'.
--
-- Fields, in order:
--
-- * the 'show'n 'ThreadId' of the linked (child) thread;
-- * the exception that 'waitCatch' reported for that thread's 'Async'.
data ExceptionInLinkedThread = ExceptionInLinkedThread String SomeException

-- | Renders in the same shape a derived 'Show' instance would use: the
-- constructor name followed by both fields at application precedence. The
-- whole thing is parenthesised when it is itself a constructor argument.
instance Show ExceptionInLinkedThread where
  showsPrec d (ExceptionInLinkedThread tidStr err) =
      showParen (d > 10) rendered
    where
      rendered = foldr (.) id
        [ showString "ExceptionInLinkedThread "
        , showsPrec 11 tidStr
        , showString " "
        , showsPrec 11 err
        ]

-- | Registered as an /asynchronous/ exception, i.e. wrapped in
-- 'E.SomeAsyncException'. This matches how it is delivered: via 'throwTo'
-- into the linked thread.
instance Exception ExceptionInLinkedThread where
  toException   = E.asyncExceptionToException
  fromException = E.asyncExceptionFromException

-- | Generalization of 'link' that links an async to an arbitrary thread.
--
-- If the async fails with any exception other than 'AsyncCancelled', that
-- exception is thrown to @tid@, wrapped in 'ExceptionInLinkedThread'.
linkTo :: (MonadAsync m, MonadFork m, MonadMask m)
       => ThreadId m -> Async m a -> m ()
linkTo tid = linkToOnly tid notCancelled
  where
    notCancelled e = not (isCancel e)

-- | Like 'linkTo', but the caller decides which failures are propagated.
--
-- Forks a watcher thread labelled @"linkToOnly " <> show t@, where @t@ is the
-- async's own thread id. The watcher waits for the async to finish. If the
-- async ended with an exception @e@ and @shouldThrow e@ holds, @e@ is thrown
-- to @tid@, wrapped in 'ExceptionInLinkedThread'. Otherwise the watcher does
-- nothing.
--
-- The watcher is started with 'forkRepeat', so it is re-run if it is itself
-- interrupted by an exception.
linkToOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
           => ThreadId m -> (SomeException -> Bool) -> Async m a -> m ()
linkToOnly tid shouldThrow a =
    void (forkRepeat watcherLabel watcher)
  where
    -- Thread of the async being watched; also recorded in the exception.
    linkedThreadId :: ThreadId m
    linkedThreadId = asyncThreadId (Proxy @m) a

    watcherLabel :: String
    watcherLabel = "linkToOnly " <> show linkedThreadId

    watcher :: m ()
    watcher = waitCatch a >>= either propagate (\_ -> return ())

    propagate :: SomeException -> m ()
    propagate e
      | shouldThrow e = throwTo tid (ExceptionInLinkedThread (show linkedThreadId) e)
      | otherwise     = return ()

-- | Link an async to the calling thread.
--
-- If the async fails with any exception other than 'AsyncCancelled', that
-- exception is re-thrown in the calling thread, wrapped in
-- 'ExceptionInLinkedThread'.
link :: (MonadAsync m, MonadFork m, MonadMask m)
     => Async m a -> m ()
link = linkOnly shouldPropagate
  where
    shouldPropagate = not . isCancel

-- | Like 'link', but only exceptions satisfying the predicate are re-thrown
-- to the calling thread.
linkOnly :: forall m a. (MonadAsync m, MonadFork m, MonadMask m)
         => (SomeException -> Bool) -> Async m a -> m ()
linkOnly shouldThrow a =
    myThreadId >>= \me -> linkToOnly me shouldThrow a

-- | Is this exception an 'AsyncCancelled'? 'link' and 'linkTo' deliberately
-- do not propagate such failures.
isCancel :: SomeException -> Bool
isCancel e =
    case fromException e of
      Just AsyncCancelled -> True
      Nothing             -> False

-- | Fork a thread, labelled with the given string, that runs @action@ and
-- starts it again whenever it ends with an exception.
--
-- The fork happens inside 'mask', so the new thread starts masked. Each run
-- of @action@ goes through @restore@, which reinstates the masking state that
-- was in effect when 'forkRepeat' was called. The loop stops after the first
-- run that returns normally.
forkRepeat :: (MonadFork m, MonadMask m) => String -> m a -> m (ThreadId m)
forkRepeat label action =
  mask $ \restore ->
    let loop = tryAll (restore action)
                 >>= either (const loop) (const (return ()))
    in forkIO $ do
         labelThisThread label
         loop

-- | 'try' with the exception type fixed to 'SomeException': any exception
-- the action raises, whatever its type, is returned as 'Left'.
tryAll :: MonadCatch m => m a -> m (Either SomeException a)
tryAll action = try action