{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A pool of concurrently running jobs whose results are delivered through
-- a shared completion queue.
--
-- Jobs are forked with 'forkJob' inside a 'withJobPool' scope.  When that
-- scope exits, every job still registered in the pool is cancelled.  Results
-- (either a job's own result or the value its handler built from a
-- synchronous exception) are read with 'collect'.
module Control.Concurrent.JobPool
  ( JobPool
  , Job (..)
  , withJobPool
  , forkJob
  , readSize
  , collect
  ) where

import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import           Data.Proxy (Proxy (..))

import           Control.Exception (SomeAsyncException (..))
import           Control.Monad (unless)
import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork (MonadThread (..))
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow

-- | A pool of running jobs, each of which produces a result of type @a@.
data JobPool m a = JobPool {
       -- | Jobs that have been forked and have not yet reported completion,
       -- keyed by the thread id of the job's thread.
       jobsVar         :: !(TVar m (Map (ThreadId m) (Async m ()))),
       -- | Results of finished jobs, in the order they completed.
       completionQueue :: !(TQueue m a)
     }

-- | A job to be run in a 'JobPool'.
data Job m a = Job (m a)                 -- ^ the action to run
                   (SomeException -> a)  -- ^ turns a synchronous exception
                                         --   thrown by the action into a result
                   String                -- ^ label given to the job's thread

-- | Run a computation with a fresh, empty 'JobPool'.
--
-- When the computation returns or throws, every job still registered in the
-- pool is cancelled, in turn, with 'cancel'.
withJobPool :: forall m a b.
               (MonadAsync m, MonadThrow m)
            => (JobPool m a -> m b) -> m b
withJobPool =
    bracket create close
  where
    create :: m (JobPool m a)
    create =
      atomically $
        JobPool <$> newTVar Map.empty
                <*> newTQueue

    close :: JobPool m a -> m ()
    close JobPool{jobsVar} = do
      jobs <- atomically (readTVar jobsVar)
      mapM_ cancel jobs

-- | Fork a job in the pool.
--
-- The job runs in its own thread, labelled with the job's label.  If the
-- action throws a synchronous exception, the job's handler turns it into a
-- result.  Asynchronous exceptions are not handled; they terminate the job.
-- The result is forced to WHNF and pushed onto the completion queue.  In the
-- same transaction the job removes itself from the pool.
forkJob :: forall m a.
           (MonadAsync m, MonadMask m)
        => JobPool m a
        -> Job     m a
        -> m ()
forkJob JobPool{jobsVar, completionQueue} (Job action handler label) =
    -- Masked so the parent is not interrupted between forking the job and
    -- registering it.  'restore' re-enables asynchronous exceptions only
    -- around the job's action.
    mask $ \restore -> do
      -- Set by the job's completion transaction.  Registration consults it,
      -- so a job that finishes before the parent registers it does not
      -- leave a stale entry in 'jobsVar'.
      doneVar <- atomically (newTVar False)

      jobAsync <- async $ do
        self <- myThreadId
        labelThread self label
        !res <- handleJust notAsyncExceptions (return . handler) $
                  restore action
        atomically $ do
          writeTQueue completionQueue res
          modifyTVar' jobsVar (Map.delete self)
          writeTVar doneVar True

      let !jobTid = asyncThreadId (Proxy :: Proxy m) jobAsync
      -- The completion transaction and this one are serialised.  Either the
      -- job later finds its entry and deletes it, or it has already marked
      -- itself done and is never inserted.
      atomically $ do
        done <- readTVar doneVar
        unless done $
          modifyTVar' jobsVar (Map.insert jobTid jobAsync)
  where
    -- Select only synchronous exceptions for the job's handler.
    -- Asynchronous ones (e.g. from 'cancel') are rethrown.
    notAsyncExceptions :: SomeException -> Maybe SomeException
    notAsyncExceptions e
      | Just (SomeAsyncException _) <- fromException e
                  = Nothing
      | otherwise = Just e

-- | Number of jobs currently registered in the pool: jobs that were forked
-- and have not yet reported completion.
readSize :: MonadSTM m => JobPool m a -> STM m Int
readSize JobPool{jobsVar} = Map.size <$> readTVar jobsVar

-- | Take the result of the next completed job.  Blocks (retries the STM
-- transaction) while no result is available.
collect :: MonadSTM m => JobPool m a -> STM m a
collect JobPool{completionQueue} = readTQueue completionQueue