{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A thread manager.
--   The manager is responsible for killing the worker threads it manages.
module Network.HTTP2.H2.Manager (
    Manager,
    start,
    stopAfter,
    forkManaged,
    forkManagedUnmask,
    withTimeout,
    KilledByHttp2ThreadManager (..),
    waitCounter0,
) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import qualified Control.Exception as E
import Data.Foldable
import Data.Map (Map)
import qualified Data.Map.Strict as Map
import qualified System.TimeManager as T

import Imports

----------------------------------------------------------------

-- | Manager to manage the thread and the timer.
--
-- It pairs a time manager ('T.Manager'), used by 'withTimeout' to register
-- timeouts, with a shared 'TVar' holding the table of managed threads.
data Manager = Manager T.Manager (TVar ManagedThreads)

-- | Table of managed threads: each thread's ID mapped to its timeout state.
type ManagedThreads = Map ThreadId TimeoutHandle

----------------------------------------------------------------

-- | Timeout state recorded for a managed thread.
data TimeoutHandle
    = ThreadWithTimeout T.Handle
    -- The thread has a time-manager handle (stored by 'withTimeout').
    | ThreadWithoutTimeout
    -- The thread has no timeout handle (the state stored by 'forkManagedUnmask').

-- | Cancel the timeout attached to a managed thread, if it has one.
--
-- A 'ThreadWithTimeout' entry has its handle passed to 'T.cancel'; a
-- 'ThreadWithoutTimeout' entry carries no handle, so nothing is done.
cancelTimeout :: TimeoutHandle -> IO ()
cancelTimeout (ThreadWithTimeout th) = T.cancel th
cancelTimeout ThreadWithoutTimeout = return ()

----------------------------------------------------------------

-- | Start a thread manager.
--
-- The returned 'Manager' wraps the given time manager together with a
-- fresh, empty table of managed threads. Threads are entered into the table
-- by 'forkManaged' / 'forkManagedUnmask' and are killed by 'stopAfter'.
start :: T.Manager -> IO Manager
start timmgr = Manager timmgr <$> newTVarIO Map.empty

----------------------------------------------------------------

-- | Exception thrown by 'stopAfter' to every thread still registered with
-- the manager.
--
-- The payload is the exception that terminated the action passed to
-- 'stopAfter', or 'Nothing' if that action returned normally.
data KilledByHttp2ThreadManager = KilledByHttp2ThreadManager (Maybe SomeException)
    deriving (Show)

-- | 'KilledByHttp2ThreadManager' is an asynchronous exception: conversion
-- to and from 'SomeException' goes through 'asyncExceptionToException' and
-- 'asyncExceptionFromException', so handlers that classify exceptions as
-- synchronous or asynchronous see it as asynchronous.
instance Exception KilledByHttp2ThreadManager where
    toException = asyncExceptionToException
    fromException = asyncExceptionFromException

-- | Stopping the manager.
--
-- The action is run in the scope of an exception handler that catches all
-- exceptions (including asynchronous ones); this allows the cleanup handler
-- to cleanup in all circumstances. If an exception is caught, it is rethrown
-- after the cleanup is complete.
--
-- Once the action has finished (normally or by exception), in order:
--
-- 1. the table of managed threads is emptied atomically;
-- 2. the timeout of every thread that was in the table is cancelled;
-- 3. 'KilledByHttp2ThreadManager' is thrown to each of those threads,
--    carrying the action's exception if there was one;
-- 4. the cleanup handler is called with that same @Maybe SomeException@;
-- 5. the action's result is returned, or its exception is rethrown.
stopAfter :: Manager -> IO a -> (Maybe SomeException -> IO ()) -> IO a
stopAfter (Manager _timmgr var) action cleanup =
    mask $ \unmask -> do
        -- Only the action itself runs unmasked; the teardown below is masked.
        ma <- try $ unmask action
        -- Take the whole table and leave an empty one behind in a single
        -- transaction, so each registered thread is handled exactly once.
        m <- atomically $ swapTVar var Map.empty
        forM_ (Map.elems m) cancelTimeout
        let er :: Maybe SomeException
            er = either Just (const Nothing) ma
        forM_ (Map.keys m) $ \tid ->
            E.throwTo tid $ KilledByHttp2ThreadManager er
        case ma of
            Left err -> cleanup (Just err) >> throwIO err
            Right a -> cleanup Nothing >> return a

----------------------------------------------------------------

-- | Fork managed thread
--
-- This guarantees that the thread ID is added to the manager's table before
-- the given action starts, and is removed again when the action terminates
-- (normally or abnormally).
--
-- The action runs with asynchronous exceptions unmasked; the bookkeeping
-- around it is done by 'forkManagedUnmask'. The 'String' is the label given
-- to the new thread.
forkManaged :: Manager -> String -> IO () -> IO ()
forkManaged mgr label io =
    forkManagedUnmask mgr label $ \unmask -> unmask io

-- | Like 'forkManaged', but run action with exceptions masked
--
-- The action receives an @unmask@ function with which it can re-enable
-- asynchronous exceptions for selected parts of its work.
--
-- The new thread labels itself, records its own 'ThreadId' in the manager's
-- table as 'ThreadWithoutTimeout', runs the action, and finally removes its
-- entry again. Any exception raised by the action is swallowed.
forkManagedUnmask
    :: Manager -> String -> ((forall x. IO x -> IO x) -> IO ()) -> IO ()
forkManagedUnmask (Manager _timmgr var) label io =
    void $ mask_ $ forkIOWithUnmask $ \unmask -> E.handle handler $ do
        labelMe label
        tid <- myThreadId
        -- Strict update: do not leave a chain of unevaluated inserts and
        -- deletes inside the shared TVar.
        atomically $ modifyTVar' var $ Map.insert tid ThreadWithoutTimeout
        -- We catch the exception and do not rethrow it: we don't want the
        -- exception printed to stderr.
        io unmask `catch` \(_e :: SomeException) -> return ()
        atomically $ modifyTVar' var $ Map.delete tid
  where
    -- Outer safety net: an exception escaping the bookkeeping itself is
    -- discarded as well rather than reported by the runtime.
    handler :: SomeException -> IO ()
    handler (E.SomeException _) = return ()

-- | Block until the manager's table of managed threads is empty.
--
-- The STM transaction retries (via 'check') for as long as at least one
-- thread is still registered.
waitCounter0 :: Manager -> IO ()
waitCounter0 (Manager _timmgr var) = atomically $ do
    m <- readTVar var
    check (Map.null m)

----------------------------------------------------------------

-- | Run an action with a time-manager handle registered for the calling
-- thread.
--
-- The handle is obtained with 'T.registerKillThread' and recorded in the
-- manager's table under the calling thread's ID as 'ThreadWithTimeout',
-- replacing whatever entry that thread had. The handle is passed to the
-- action and cancelled with 'T.cancel' when the action finishes, whether
-- normally or by exception ('E.bracket').
--
-- NOTE(review): the table entry is not reset to 'ThreadWithoutTimeout'
-- afterwards, so it keeps referring to the cancelled handle — confirm this
-- is intended.
withTimeout :: Manager -> (T.Handle -> IO a) -> IO a
withTimeout (Manager timmgr var) action =
    E.bracket register unregister action
  where
    register :: IO T.Handle
    register = do
        tid <- myThreadId
        th <- T.registerKillThread timmgr $ return ()
        -- overriding ThreadWithoutTimeout
        atomically $ modifyTVar' var $ Map.insert tid $ ThreadWithTimeout th
        return th

    unregister :: T.Handle -> IO ()
    unregister = T.cancel