{-# LANGUAGE CPP                #-}
{-# LANGUAGE DefaultSignatures  #-}
{-# LANGUAGE FlexibleContexts   #-}
{-# LANGUAGE FlexibleInstances  #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeFamilies       #-}

module Control.Monad.Class.MonadTimer (
    MonadDelay(..)
  , MonadTimer(..)
  , TimeoutState(..)

  , DiffTime
  , diffTimeToMicrosecondsAsInt
  , microsecondsAsIntToDiffTime
  ) where

import qualified Control.Concurrent as IO
import qualified Control.Concurrent.STM.TVar as STM
import           Control.Exception (assert)
import           Control.Monad.Reader
import qualified Control.Monad.STM as STM
import           Data.Kind (Type)
import           Data.Time.Clock (DiffTime, diffTimeToPicoseconds)

#if defined(__GLASGOW_HASKELL__) && !defined(mingw32_HOST_OS) && !defined(__GHCJS__)
import qualified GHC.Event as GHC (TimeoutKey, getSystemTimerManager,
                     registerTimeout, unregisterTimeout, updateTimeout)
#endif

import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM

import qualified System.Timeout as IO

-- | The lifecycle of a 'Timeout'.
--
-- A timeout starts out 'TimeoutPending' and moves, exactly once, to one of
-- the two terminal states.
data TimeoutState
  = TimeoutPending    -- ^ Neither fired nor cancelled yet.
  | TimeoutFired      -- ^ The requested duration has elapsed.
  | TimeoutCancelled  -- ^ Cancelled with 'cancelTimeout' before firing.

class Monad m => MonadDelay m where
  -- | Suspend the calling thread for (at least) the given duration.
  --
  -- The default implementation, available for any 'MonadTimer', creates a
  -- one-off timeout with 'newTimeout' and then blocks in 'atomically' until
  -- 'awaitTimeout' reports that the timeout is no longer pending.
  threadDelay :: DiffTime -> m ()

  default threadDelay :: MonadTimer m => DiffTime -> m ()
  threadDelay d = do
    t <- newTimeout d
    _ <- atomically (awaitTimeout t)
    return ()

class (MonadSTM m, MonadDelay m) => MonadTimer m where
  -- | A handle on a single timeout; each instance picks its representation.
  data Timeout m :: Type

  -- | Create a timeout that fires once the given duration has elapsed.
  --
  -- A fresh timeout is 'TimeoutPending'. It either fires at or after the
  -- requested duration, ending up 'TimeoutFired', or is cancelled first with
  -- 'cancelTimeout', ending up 'TimeoutCancelled'.
  --
  -- Once fired or cancelled a timeout /cannot/ go back to pending (that
  -- would be very racy); create a new timeout when you need that.
  --
  newTimeout     :: DiffTime -> m (Timeout m)

  -- | Non-blocking read of a timeout's current state. Callers that want to
  -- wait must 'retry' themselves, or use 'awaitTimeout', which waits for the
  -- fired or cancelled outcome.
  --
  -- Remember to handle 'TimeoutCancelled' if you ever call 'cancelTimeout'.
  --
  readTimeout    :: Timeout m -> STM m TimeoutState

  -- | Move the expiry so that the timer fires the given duration from now.
  --
  -- Racing this against the timer firing is safe: if the timer fires first
  -- the update has no effect.
  --
  -- The new expiry may be earlier or later than the original one, although
  -- pulling a timeout forward is arguably an application design flaw.
  --
  updateTimeout  :: Timeout m -> DiffTime -> m ()

  -- | Cancel a timeout that has not fired yet, moving it to
  -- 'TimeoutCancelled'. Code that reads the timeout state must cope with
  -- such cancellation.
  --
  -- Racing this against the timer firing is safe: if the timer fires first
  -- the cancellation has no effect.
  --
  cancelTimeout  :: Timeout m -> m ()

  -- | Wait (by 'retry'ing while pending) for the timeout to settle.
  -- Yields @True@ when it fired and @False@ when it was cancelled.
  awaitTimeout   :: Timeout m -> STM m Bool
  awaitTimeout t = readTimeout t >>= outcome
    where
      outcome TimeoutPending   = retry
      outcome TimeoutFired     = return True
      outcome TimeoutCancelled = return False

  -- | Obtain a 'TVar' that is set to @True@ once the given duration has
  -- elapsed. By default this is 'defaultRegisterDelay', which needs
  -- 'MonadFork'.
  registerDelay :: DiffTime -> m (TVar m Bool)

  default registerDelay :: MonadFork m => DiffTime -> m (TVar m Bool)
  registerDelay = defaultRegisterDelay

  -- | Run an action with a time limit. The result is @Just@ the action's
  -- result if it finishes in time, @Nothing@ otherwise, following the
  -- convention of "System.Timeout" (which the 'IO' instance delegates to).
  timeout :: DiffTime -> m a -> m (Maybe a)


-- | Generic 'registerDelay' built from 'newTimeout' and 'forkIO'.
--
-- Allocates a 'TVar' initialised to @False@ and a timeout for the given
-- duration. It then forks a thread that blocks until the timeout is no
-- longer pending and stores the outcome of 'awaitTimeout' (@True@ once
-- fired) in the 'TVar'. Because it goes through 'newTimeout', any limit an
-- instance places on representable durations applies here as well.
defaultRegisterDelay :: ( MonadTimer m
                        , MonadFork  m
                        )
                     => DiffTime
                     -> m (TVar m Bool)
defaultRegisterDelay d = do
    var <- atomically (newTVar False)
    tmo <- newTimeout d
    -- The timeout is private to this function and never cancelled, so the
    -- value written here is the "fired" outcome.
    let publish = atomically (awaitTimeout tmo >>= writeTVar var)
    _ <- forkIO publish
    return var

--
-- Instances for IO
--

-- | With 'threadDelay' one can use arbitrarily large 'DiffTime's, which is an
-- advantage over 'IO.threadDelay'.
--
instance MonadDelay IO where
  threadDelay = sleep
    where
      -- The longest delay a single 'IO.threadDelay' call accepts:
      -- 'maxBound' microseconds.
      chunk :: DiffTime
      chunk = microsecondsAsIntToDiffTime maxBound

      -- Sleep whole chunks while the remainder is too long for one call,
      -- then sleep for what is left.
      sleep :: DiffTime -> IO ()
      sleep remaining
        | remaining > chunk =
            IO.threadDelay maxBound >> sleep (remaining - chunk)
        | otherwise =
            IO.threadDelay (diffTimeToMicrosecondsAsInt remaining)

#if defined(__GLASGOW_HASKELL__) && !defined(mingw32_HOST_OS) && !defined(__GHCJS__)
-- Timeouts backed by GHC's system timer manager: the state lives in a 'TVar'
-- and the timer manager callback moves it from pending to fired.
instance MonadTimer IO where
  data Timeout IO = TimeoutIO !(STM.TVar TimeoutState) !GHC.TimeoutKey

  readTimeout (TimeoutIO var _key) = STM.readTVar var

  newTimeout d = do
      var <- STM.newTVarIO TimeoutPending
      mgr <- GHC.getSystemTimerManager
      key <- GHC.registerTimeout mgr (diffTimeToMicrosecondsAsInt d)
                                     (STM.atomically (timeoutAction var))
      return (TimeoutIO var key)
    where
      -- Callback run by the timer manager. A cancelled timeout is left
      -- alone; finding it already fired would mean the callback ran twice.
      timeoutAction :: STM.TVar TimeoutState -> STM.STM ()
      timeoutAction var = do
        x <- STM.readTVar var
        case x of
          TimeoutPending   -> STM.writeTVar var TimeoutFired
          TimeoutFired     -> error "MonadTimer(IO): invariant violation"
          TimeoutCancelled -> return ()

  -- In GHC's TimerManager this has no effect if the timer already fired.
  -- It is safe to race against the timer firing.
  updateTimeout (TimeoutIO _var key) d = do
      mgr <- GHC.getSystemTimerManager
      GHC.updateTimeout mgr key (diffTimeToMicrosecondsAsInt d)

  -- Mark the timeout cancelled (unless it already fired), then unregister
  -- the callback. Should the callback still run, 'timeoutAction' leaves a
  -- cancelled timeout untouched.
  cancelTimeout (TimeoutIO var key) = do
      STM.atomically $ do
        x <- STM.readTVar var
        case x of
          TimeoutPending   -> STM.writeTVar var TimeoutCancelled
          TimeoutFired     -> return ()
          TimeoutCancelled -> return ()
      mgr <- GHC.getSystemTimerManager
      GHC.unregisterTimeout mgr key
#else
-- Portable fallback: the current delay 'TVar' (from 'STM.registerDelay') is
-- kept in an outer 'TVar' so it can be swapped by 'updateTimeout'; a separate
-- flag records cancellation.
instance MonadTimer IO where
  data Timeout IO = TimeoutIO !(STM.TVar (STM.TVar Bool)) !(STM.TVar Bool)

  readTimeout (TimeoutIO timeoutvarvar cancelvar) = do
    canceled <- STM.readTVar cancelvar
    fired    <- STM.readTVar =<< STM.readTVar timeoutvarvar
    case (canceled, fired) of
      (True, _)  -> return TimeoutCancelled
      (_, False) -> return TimeoutPending
      (_, True)  -> return TimeoutFired

  newTimeout d = do
    timeoutvar    <- STM.registerDelay (diffTimeToMicrosecondsAsInt d)
    timeoutvarvar <- STM.newTVarIO timeoutvar
    cancelvar     <- STM.newTVarIO False
    return (TimeoutIO timeoutvarvar cancelvar)

  -- Only replace the delay while the current one has not fired yet.
  -- Swapping unconditionally would make a fired timeout read as pending
  -- again, breaking the "no effect if the timer fires first" contract.
  updateTimeout (TimeoutIO timeoutvarvar _cancelvar) d = do
    timeoutvar' <- STM.registerDelay (diffTimeToMicrosecondsAsInt d)
    STM.atomically $ do
      fired <- STM.readTVar =<< STM.readTVar timeoutvarvar
      when (not fired) $ STM.writeTVar timeoutvarvar timeoutvar'

  cancelTimeout (TimeoutIO timeoutvarvar cancelvar) =
    STM.atomically $ do
      fired <- STM.readTVar =<< STM.readTVar timeoutvarvar
      when (not fired) $ STM.writeTVar cancelvar True
#endif

  -- For delays less than or equal to @maxBound :: Int@ microseconds this is
  -- exactly 'STM.registerDelay'. Longer delays are handled by a helper
  -- thread. It sleeps with the 'MonadDelay' 'IO' 'threadDelay', which splits
  -- arbitrarily large 'DiffTime's into 'Int'-sized chunks, and then sets the
  -- 'TVar'.
  --
  -- This deliberately avoids 'defaultRegisterDelay' and 'newTimeout'. Those
  -- convert the whole duration with 'diffTimeToMicrosecondsAsInt', which
  -- cannot represent such long delays (issue #2184; especially easy to hit
  -- on 32-bit architectures).
  registerDelay d
      | d <= maxDelay =
        STM.registerDelay (diffTimeToMicrosecondsAsInt d)
      | otherwise = do
        var <- STM.newTVarIO False
        _ <- IO.forkIO $ do
          threadDelay d
          STM.atomically (STM.writeTVar var True)
        return var
    where
      maxDelay :: DiffTime
      maxDelay = microsecondsAsIntToDiffTime maxBound

  -- Delegates to 'IO.timeout'.
  -- NOTE(review): the duration is converted with 'diffTimeToMicrosecondsAsInt',
  -- so durations whose microsecond count does not fit in an 'Int' are not
  -- supported here.
  timeout = IO.timeout . diffTimeToMicrosecondsAsInt


-- | Convert a 'DiffTime' to a whole number of microseconds. The picosecond
-- count is divided with 'div', i.e. rounded towards negative infinity.
--
-- Only microsecond counts that fit within an 'Int' can be represented. On
-- 32-bit systems that means 2^31 usec, which is only ~35 minutes.
-- Too-large values trip the 'assert' when assertions are enabled. When they
-- are compiled out, the result saturates at 'maxBound' / 'minBound' instead
-- of silently wrapping around, which could turn a very long delay into a
-- short or negative one.
diffTimeToMicrosecondsAsInt :: DiffTime -> Int
diffTimeToMicrosecondsAsInt d =
    assert (usec <= maxInt) $
      fromInteger (max minInt (min maxInt usec))
  where
    usec :: Integer
    usec = diffTimeToPicoseconds d `div` 1_000_000

    maxInt, minInt :: Integer
    maxInt = toInteger (maxBound :: Int)
    minInt = toInteger (minBound :: Int)

-- | Interpret an 'Int' as a number of microseconds and turn it into a
-- 'DiffTime' (the inverse of 'diffTimeToMicrosecondsAsInt' on whole
-- microseconds).
microsecondsAsIntToDiffTime :: Int -> DiffTime
microsecondsAsIntToDiffTime usec = fromIntegral usec / 1_000_000

--
-- Lift to ReaderT
--

instance MonadDelay m => MonadDelay (ReaderT r m) where
  -- Delay in the underlying monad; the reader environment is not consulted.
  threadDelay d = lift (threadDelay d)

-- Timeouts in 'ReaderT' are the underlying monad's timeouts behind a
-- newtype. Every operation unwraps the handle and delegates, lifting the
-- result back into 'ReaderT' where needed.
instance (MonadTimer m, MonadFork m) => MonadTimer (ReaderT r m) where
  newtype Timeout (ReaderT r m) = WrapTimeoutReader {
      unwrapTimeoutReader :: Timeout m
    }

  newTimeout d = lift (fmap WrapTimeoutReader (newTimeout d))

  readTimeout (WrapTimeoutReader t) = readTimeout t

  updateTimeout (WrapTimeoutReader t) d = lift (updateTimeout t d)

  cancelTimeout (WrapTimeoutReader t) = lift (cancelTimeout t)

  -- Run the guarded action with the current environment inside the
  -- underlying monad's 'timeout'.
  timeout d ma = ReaderT $ \env -> timeout d (runReaderT ma env)