{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}


module Control.Concurrent.JobPool (
    JobPool,
    Job(..),
    withJobPool,
    forkJob,
    readSize,
    collect
  ) where

import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import           Data.Proxy (Proxy (..))

import           Control.Exception (SomeAsyncException (..))
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork (MonadThread (..))
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow


-- | A pool of concurrently running jobs.
--
-- The constructor is not exported; obtain a pool from 'withJobPool', start
-- jobs with 'forkJob', and read finished jobs' results with 'collect'.
data JobPool m a = JobPool
  { jobsVar         :: !(TVar m (Map (ThreadId m) (Async m ())))
    -- ^ Jobs currently recorded in the pool, keyed by their thread id.
  , completionQueue :: !(TQueue m a)
    -- ^ Results of finished jobs, in the order they were written.
  }

-- | A unit of work that can be run in a 'JobPool' with 'forkJob'.
data Job m a =
    Job (m a)
        -- ^ the action to run
        (SomeException -> a)
        -- ^ converts a synchronous exception thrown by the action into a
        -- result
        String
        -- ^ label attached to the job's thread

-- | Run an action with a fresh, empty 'JobPool'.
--
-- The pool is released by 'bracket', so whether the action returns normally
-- or throws, every job still recorded in the pool at that point is
-- cancelled.
withJobPool :: forall m a b.
               (MonadAsync m, MonadThrow m)
            => (JobPool m a -> m b) -> m b
withJobPool body = bracket newPool cancelAll body
  where
    -- Allocate an empty job map and an empty completion queue in a single
    -- transaction.
    newPool :: m (JobPool m a)
    newPool = atomically $ do
      running <- newTVar Map.empty
      results <- newTQueue
      return (JobPool running results)

    -- Snapshot the jobs recorded in the pool and cancel them one by one.
    cancelAll :: JobPool m a -> m ()
    cancelAll pool = do
      running <- atomically (readTVar (jobsVar pool))
      mapM_ cancel (Map.elems running)

-- | Fork a 'Job' into the pool.
--
-- The job's action runs in a new 'async' thread labelled with the job's
-- label. A synchronous exception thrown by the action is turned into a
-- result by the job's handler. Asynchronous exceptions (e.g. from 'cancel')
-- are not caught, so an interrupted job produces no result. When the job
-- finishes, its result is written to the completion queue (read it with
-- 'collect') and the job is removed from the pool's job map.
--
-- The setup runs under 'mask' so that, barring interruptible operations,
-- the caller is not interrupted between forking the job and recording it
-- in the pool. The job's action itself runs with the caller's original
-- masking state restored.
forkJob :: forall m a.
           (MonadAsync m, MonadMask m)
        => JobPool m a
        -> Job     m a
        -> m ()
forkJob JobPool{jobsVar, completionQueue} (Job action handler label) =
    mask $ \restore -> do
      jobAsync <- async $ do
        tid <- myThreadId
        labelThread tid label
        !res <- handleJust notAsyncExceptions (return . handler) $
                 restore action
        atomically $ do
          -- The parent records this job only after 'async' returns, so a
          -- short job can get here first. Deleting before the entry exists
          -- would let the parent's later insert leave a stale entry that
          -- is never removed. So wait until our own entry is visible, then
          -- publish the result and deregister in one transaction.
          registered <- Map.member tid <$> readTVar jobsVar
          if registered
            then do
              writeTQueue completionQueue res
              modifyTVar' jobsVar (Map.delete tid)
            else retry

      -- Non-blocking transaction under 'mask': once the job is forked, it
      -- is always recorded, which the wait above relies on.
      let !tid = asyncThreadId (Proxy :: Proxy m) jobAsync
      atomically $ modifyTVar' jobsVar (Map.insert tid jobAsync)
  where
    -- Only synchronous exceptions are handed to the job's handler;
    -- asynchronous ones propagate and terminate the job.
    notAsyncExceptions :: SomeException -> Maybe SomeException
    notAsyncExceptions e
      | Just (SomeAsyncException _) <- fromException e
                  = Nothing
      | otherwise = Just e

-- | Number of jobs currently recorded in the pool's job map.
readSize :: MonadSTM m => JobPool m a -> STM m Int
readSize pool = do
    running <- readTVar (jobsVar pool)
    return (Map.size running)

-- | Take the next finished job's result from the pool's completion queue.
-- This uses 'readTQueue', which retries while the queue is empty.
collect :: MonadSTM m => JobPool m a -> STM m a
collect = readTQueue . completionQueue