{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Monad.Class.MonadTimer (
MonadDelay(..)
, MonadTimer(..)
, TimeoutState(..)
, DiffTime
, diffTimeToMicrosecondsAsInt
, microsecondsAsIntToDiffTime
) where
import qualified Control.Concurrent as IO
import qualified Control.Concurrent.STM.TVar as STM
import Control.Exception (assert)
import Control.Monad.Reader
import qualified Control.Monad.STM as STM
import Data.Kind (Type)
import Data.Time.Clock (DiffTime, diffTimeToPicoseconds)
#if defined(__GLASGOW_HASKELL__) && !defined(mingw32_HOST_OS) && !defined(__GHCJS__)
import qualified GHC.Event as GHC (TimeoutKey, getSystemTimerManager,
registerTimeout, unregisterTimeout, updateTimeout)
#endif
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import qualified System.Timeout as IO
-- | Observable state of a 'Timeout'.
--
-- 'awaitTimeout' blocks while the state is 'TimeoutPending', and returns
-- 'True' for 'TimeoutFired' and 'False' for 'TimeoutCancelled'.
data TimeoutState
  = TimeoutPending    -- ^ neither fired nor cancelled yet
  | TimeoutFired      -- ^ the duration elapsed
  | TimeoutCancelled  -- ^ cancelled (via 'cancelTimeout') while still pending
-- | Monads that can suspend the calling thread for a given duration.
class Monad m => MonadDelay m where
  -- | Block the calling thread for the given duration.
  threadDelay :: DiffTime -> m ()

  -- Default for monads that already have timeouts: start a fresh 'Timeout'
  -- and block in STM until it is no longer pending. The fired/cancelled
  -- flag returned by 'awaitTimeout' is not needed and is discarded.
  default threadDelay :: MonadTimer m => DiffTime -> m ()
  threadDelay d = do
    tmo <- newTimeout d
    void (atomically (awaitTimeout tmo))
-- | Monads with one-shot timeouts whose state can be observed from STM.
class (MonadSTM m, MonadDelay m) => MonadTimer m where
  -- | Handle for a single timeout; the representation is monad-specific.
  data Timeout m :: Type

  -- | Start a timeout that fires after the given duration.
  newTimeout    :: DiffTime -> m (Timeout m)

  -- | Read the current 'TimeoutState' inside a transaction.
  readTimeout   :: Timeout m -> STM m TimeoutState

  -- | Change the duration of an existing timeout.
  updateTimeout :: Timeout m -> DiffTime -> m ()

  -- | Cancel a timeout.
  cancelTimeout :: Timeout m -> m ()

  -- | Wait for a timeout to leave the pending state: 'retry' while it is
  -- 'TimeoutPending', then 'True' if it fired or 'False' if it was
  -- cancelled.
  awaitTimeout  :: Timeout m -> STM m Bool
  awaitTimeout t =
    readTimeout t >>= \st ->
      case st of
        TimeoutPending   -> retry
        TimeoutFired     -> return True
        TimeoutCancelled -> return False

  -- | Return a 'TVar' that is initially 'False' and is set to 'True' once
  -- the given duration has elapsed.
  registerDelay :: DiffTime -> m (TVar m Bool)

  -- The default forks a watcher thread per call ('defaultRegisterDelay').
  default registerDelay :: MonadFork m => DiffTime -> m (TVar m Bool)
  registerDelay = defaultRegisterDelay

  -- | Run an action, returning 'Nothing' if it does not complete within the
  -- given duration.
  timeout :: DiffTime -> m a -> m (Maybe a)
-- | Generic 'registerDelay' built from 'newTimeout'.
--
-- Allocates a 'TVar' holding 'False', starts a 'Timeout', and forks a thread
-- that waits on the timeout and stores the result of 'awaitTimeout' in the
-- 'TVar' ('True' if it fired, 'False' if it was cancelled). Each call costs
-- one extra thread, which stays blocked until the timeout leaves the pending
-- state.
defaultRegisterDelay :: (MonadTimer m, MonadFork m) => DiffTime -> m (TVar m Bool)
defaultRegisterDelay d = do
    var <- atomically (newTVar False)
    tmo <- newTimeout d
    let watcher = atomically (awaitTimeout tmo >>= writeTVar var)
    _ <- forkIO watcher
    return var
-- | 'IO.threadDelay' takes its argument as an 'Int' number of microseconds,
-- so a longer delay is served as successive waits of at most
-- @maxBound :: Int@ microseconds each.
instance MonadDelay IO where
  threadDelay = loop
    where
      -- the longest wait a single 'IO.threadDelay' call can express
      chunk :: DiffTime
      chunk = microsecondsAsIntToDiffTime maxBound

      loop :: DiffTime -> IO ()
      loop remaining
        | remaining > chunk = IO.threadDelay maxBound >> loop (remaining - chunk)
        | otherwise         = IO.threadDelay (diffTimeToMicrosecondsAsInt remaining)
#if defined(__GLASGOW_HASKELL__) && !defined(mingw32_HOST_OS) && !defined(__GHCJS__)
-- | 'IO' timeouts backed by GHC's system timer manager. The 'STM.TVar' holds
-- the 'TimeoutState'; the 'GHC.TimeoutKey' is used to update or unregister
-- the underlying timer.
instance MonadTimer IO where
  data Timeout IO = TimeoutIO !(STM.TVar TimeoutState) !GHC.TimeoutKey

  readTimeout (TimeoutIO var _key) = STM.readTVar var

  newTimeout = \d -> do
      var <- STM.newTVarIO TimeoutPending
      mgr <- GHC.getSystemTimerManager
      key <- GHC.registerTimeout mgr (diffTimeToMicrosecondsAsInt d)
                                     (STM.atomically (timeoutAction var))
      return (TimeoutIO var key)
    where
      -- Callback handed to the timer manager. It marks a pending timeout as
      -- fired and leaves a cancelled one alone. The callback is expected to
      -- run at most once, so seeing 'TimeoutFired' here is an invariant
      -- violation.
      timeoutAction :: STM.TVar TimeoutState -> STM.STM ()
      timeoutAction var = do
        x <- STM.readTVar var
        case x of
          TimeoutPending   -> STM.writeTVar var TimeoutFired
          TimeoutFired     -> error "MonadTimer(IO): invariant violation"
          TimeoutCancelled -> return ()

  updateTimeout (TimeoutIO _var key) d = do
      mgr <- GHC.getSystemTimerManager
      GHC.updateTimeout mgr key (diffTimeToMicrosecondsAsInt d)

  -- Mark the timeout cancelled (only if it is still pending), then remove
  -- the timer from the manager.
  cancelTimeout (TimeoutIO var key) = do
      STM.atomically $ do
        x <- STM.readTVar var
        case x of
          TimeoutPending   -> STM.writeTVar var TimeoutCancelled
          TimeoutFired     -> return ()
          TimeoutCancelled -> return ()
      mgr <- GHC.getSystemTimerManager
      GHC.unregisterTimeout mgr key
#else
-- | Fallback for platforms without GHC's event manager. A timeout is an
-- 'STM.registerDelay' flag, which 'updateTimeout' swaps for a fresh one, plus
-- a separate cancellation flag.
instance MonadTimer IO where
  data Timeout IO = TimeoutIO !(STM.TVar (STM.TVar Bool)) !(STM.TVar Bool)

  readTimeout (TimeoutIO timeoutvarvar cancelvar) = do
    canceled <- STM.readTVar cancelvar
    fired    <- STM.readTVar =<< STM.readTVar timeoutvarvar
    case (canceled, fired) of
      (True, _)  -> return TimeoutCancelled
      (_, False) -> return TimeoutPending
      (_, True)  -> return TimeoutFired

  newTimeout d = do
    timeoutvar    <- STM.registerDelay (diffTimeToMicrosecondsAsInt d)
    timeoutvarvar <- STM.newTVarIO timeoutvar
    cancelvar     <- STM.newTVarIO False
    return (TimeoutIO timeoutvarvar cancelvar)

  updateTimeout (TimeoutIO timeoutvarvar _cancelvar) d = do
    timeoutvar' <- STM.registerDelay (diffTimeToMicrosecondsAsInt d)
    STM.atomically $ STM.writeTVar timeoutvarvar timeoutvar'

  cancelTimeout (TimeoutIO timeoutvarvar cancelvar) =
    STM.atomically $ do
      fired <- STM.readTVar =<< STM.readTVar timeoutvarvar
      when (not fired) $ STM.writeTVar cancelvar True
#endif

  -- | Methods shared by both 'IO' instances above.
  --
  -- Delays of at most @maxBound :: Int@ microseconds go straight to
  -- 'STM.registerDelay'. Longer delays cannot be handed to 'newTimeout':
  -- it converts the whole duration to an 'Int' with
  -- 'diffTimeToMicrosecondsAsInt', whose range check is only an 'assert',
  -- so the value would silently overflow. (On a 32-bit 'Int' the limit is
  -- about 35 minutes.) Instead, fork a thread that sleeps using the chunked
  -- 'threadDelay' of the 'MonadDelay' 'IO' instance and then sets the flag.
  -- Like the thread in 'defaultRegisterDelay', it is never cancelled.
  registerDelay d
      | d <= maxDelay =
        STM.registerDelay (diffTimeToMicrosecondsAsInt d)
      | otherwise = do
        var <- STM.newTVarIO False
        _ <- IO.forkIO $ do
          threadDelay d
          STM.atomically (STM.writeTVar var True)
        return var
    where
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

  -- NOTE(review): like 'newTimeout' and 'updateTimeout', this converts the
  -- whole duration to a single 'Int' of microseconds, so durations beyond
  -- @maxBound :: Int@ microseconds are not handled correctly.
  timeout = IO.timeout . diffTimeToMicrosecondsAsInt
-- | Convert a 'DiffTime' to a whole number of microseconds.
--
-- Rounds towards negative infinity, so a positive duration loses any
-- sub-microsecond remainder. The result must fit in an 'Int', but this is
-- checked only by 'assert'. GHC drops assertions when they are ignored
-- (e.g. under @-O@), and then an out-of-range value is silently truncated by
-- 'fromIntegral'.
diffTimeToMicrosecondsAsInt :: DiffTime -> Int
diffTimeToMicrosecondsAsInt d =
    assert fitsInInt (fromIntegral micros)
  where
    micros :: Integer
    micros = diffTimeToPicoseconds d `div` 1_000_000

    fitsInInt :: Bool
    fitsInInt = micros <= toInteger (maxBound :: Int)
-- | Interpret an 'Int' as a number of microseconds. For every 'Int' @n@,
-- @diffTimeToMicrosecondsAsInt (microsecondsAsIntToDiffTime n) == n@, since
-- 'DiffTime' has picosecond resolution.
microsecondsAsIntToDiffTime :: Int -> DiffTime
microsecondsAsIntToDiffTime us = fromIntegral us / 1_000_000
-- | Delays in 'ReaderT' run in the base monad; the environment is unused.
instance MonadDelay m => MonadDelay (ReaderT r m) where
  threadDelay d = lift (threadDelay d)
-- | Timeouts in 'ReaderT' delegate to the base monad @m@. A 'Timeout' is the
-- base monad's timeout wrapped in a newtype. 'timeout' passes the environment
-- through to the guarded action; the other operations do not use it.
instance (MonadTimer m, MonadFork m) => MonadTimer (ReaderT r m) where
  newtype Timeout (ReaderT r m) = WrapTimeoutReader { unwrapTimeoutReader :: Timeout m }

  newTimeout    d    = lift $ WrapTimeoutReader <$> newTimeout d
  updateTimeout t    = lift . updateTimeout (unwrapTimeoutReader t)
  cancelTimeout t    = lift $ cancelTimeout (unwrapTimeoutReader t)
  timeout       d ma = ReaderT $ timeout d . runReaderT ma

  -- 'STM (ReaderT r m)' is 'STM m', so no lifting is needed inside STM.
  readTimeout   t    = readTimeout $ unwrapTimeoutReader t

  -- Delegate like every other method. Without this, the class default
  -- ('defaultRegisterDelay') would fork a watcher thread per call, even when
  -- the base monad has a cheaper native 'registerDelay' (e.g. 'IO').
  -- 'TVar (ReaderT r m)' coincides with 'TVar m' because both are indexed
  -- by the same STM monad.
  registerDelay = lift . registerDelay