{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE TypeFamilies          #-}

-- to preserve 'HasCallstack' constraint on 'checkInvariant'
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Control.Monad.Class.MonadSTM.Strict
  ( module X
  , LazyTVar
  , LazyTMVar
    -- * 'StrictTVar'
  , StrictTVar
  , castStrictTVar
  , toLazyTVar
  , newTVar
  , newTVarIO
  , newTVarWithInvariantIO
  , readTVar
  , writeTVar
  , modifyTVar
  , stateTVar
    -- * 'StrictTMVar'
  , StrictTMVar
  , castStrictTMVar
  , newTMVar
  , newTMVarIO
  , newEmptyTMVarIO
  , newEmptyTMVar
  , takeTMVar
  , tryTakeTMVar
  , putTMVar
  , tryPutTMVar
  , readTMVar
  , tryReadTMVar
  , swapTMVar
  , isEmptyTMVar
    -- ** Low-level API
  , checkInvariant
    -- * Deprecated API
  , updateTVar
  , newTVarM
  , newTVarWithInvariantM
  , newTMVarM
  , newEmptyTMVarM
  ) where

import           Control.Monad.Class.MonadSTM as X hiding (LazyTMVar, LazyTVar,
                     TMVar, TVar, isEmptyTMVar, modifyTVar, newEmptyTMVar,
                     newEmptyTMVarIO, newEmptyTMVarM, newTMVar, newTMVarIO,
                     newTMVarM, newTVar, newTVarIO, newTVarM, putTMVar,
                     readTMVar, readTVar, stateTVar, swapTMVar, takeTMVar,
                     tryPutTMVar, tryReadTMVar, tryTakeTMVar, writeTVar)
import qualified Control.Monad.Class.MonadSTM as Lazy
import           GHC.Stack

{-------------------------------------------------------------------------------
  Lazy TVar
-------------------------------------------------------------------------------}

-- | Synonym for the @TVar@ type of "Control.Monad.Class.MonadSTM", which this
-- module imports qualified as @Lazy@.
type LazyTVar  m = Lazy.TVar m
-- | Synonym for the @TMVar@ type of "Control.Monad.Class.MonadSTM", which this
-- module imports qualified as @Lazy@.
type LazyTMVar m = Lazy.TMVar m

{-------------------------------------------------------------------------------
  Strict TVar
-------------------------------------------------------------------------------}

-- | A 'LazyTVar' bundled with an invariant on the values stored in it.
--
-- Both fields are strict, so constructing a 'StrictTVar' forces the invariant
-- function and the underlying variable (not the stored value) to WHNF.
data StrictTVar m a = StrictTVar
   { invariant :: !(a -> Maybe String)
     -- ^ Invariant checked whenever updating the 'StrictTVar'.
     -- 'Nothing' means the value is acceptable; @'Just' err@ describes the
     -- violation.
   , tvar      :: !(LazyTVar m a)
     -- ^ The underlying lazy variable that actually holds the value.
   }

-- | Re-tag a 'StrictTVar' from monad @m@ to monad @n@.
--
-- This is only possible when both monads use the same underlying 'LazyTVar'
-- type, which is what the equality constraint demands. The invariant and the
-- underlying variable are carried over unchanged.
castStrictTVar :: LazyTVar m ~ LazyTVar n
               => StrictTVar m a -> StrictTVar n a
castStrictTVar StrictTVar{invariant, tvar} = StrictTVar{invariant, tvar}

-- | Get the underlying @TVar@
--
-- Since we obviously cannot guarantee that updates to this 'LazyTVar' will be
-- strict, this should be used with caution. Writes made directly to the
-- returned variable also bypass the invariant check performed by 'writeTVar'.
toLazyTVar :: StrictTVar m a -> LazyTVar m a
toLazyTVar StrictTVar { tvar } = tvar

-- | Create a 'StrictTVar' inside an STM transaction.
--
-- The initial value is forced to WHNF (bang pattern) before the variable is
-- created. The new variable carries the trivial invariant @'const' 'Nothing'@,
-- i.e. every value is accepted.
newTVar :: MonadSTM m => a -> STM m (StrictTVar m a)
newTVar !a = StrictTVar (const Nothing) <$> Lazy.newTVar a

-- | Create a 'StrictTVar' directly in @m@, without an invariant.
--
-- Defined as 'newTVarWithInvariantIO' with the trivial invariant
-- @'const' 'Nothing'@.
newTVarIO :: MonadSTM m => a -> m (StrictTVar m a)
newTVarIO = newTVarWithInvariantIO (const Nothing)

-- | Deprecated synonym for 'newTVarIO'.
newTVarM :: MonadSTM m => a -> m (StrictTVar m a)
newTVarM = newTVarIO
{-# DEPRECATED newTVarM "Use newTVarIO" #-}

-- | Create a 'StrictTVar' directly in @m@ with the given invariant.
--
-- The initial value is forced to WHNF (bang pattern), and the invariant is
-- applied to it through 'checkInvariant' before the variable is created. The
-- same invariant is stored in the variable for later updates.
newTVarWithInvariantIO :: (MonadSTM m, HasCallStack)
                       => (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
                       -> a
                       -> m (StrictTVar m a)
newTVarWithInvariantIO invariant !a =
    checkInvariant (invariant a) $
    StrictTVar invariant <$> Lazy.newTVarIO a

-- | Deprecated synonym for 'newTVarWithInvariantIO'.
newTVarWithInvariantM :: (MonadSTM m, HasCallStack)
                      => (a -> Maybe String) -- ^ Invariant (expect 'Nothing')
                      -> a
                      -> m (StrictTVar m a)
newTVarWithInvariantM = newTVarWithInvariantIO
{-# DEPRECATED newTVarWithInvariantM "Use newTVarWithInvariantIO" #-}

-- | Read the current value of a 'StrictTVar'.
--
-- Delegates to the lazy @readTVar@ on the underlying variable; the invariant
-- is not consulted on reads.
readTVar :: MonadSTM m => StrictTVar m a -> STM m a
readTVar StrictTVar { tvar } = Lazy.readTVar tvar

-- | Write a new value to a 'StrictTVar'.
--
-- The value is forced to WHNF (bang pattern) and the variable's invariant is
-- applied to it through 'checkInvariant' before the value is stored in the
-- underlying lazy variable.
writeTVar :: (MonadSTM m, HasCallStack) => StrictTVar m a -> a -> STM m ()
writeTVar StrictTVar { tvar, invariant } !a =
    checkInvariant (invariant a) $
    Lazy.writeTVar tvar a

-- | Apply a function to the contents of a 'StrictTVar'.
--
-- Implemented as a 'readTVar' followed by a 'writeTVar', so the new value is
-- forced to WHNF and checked against the invariant like any other write.
modifyTVar :: MonadSTM m => StrictTVar m a -> (a -> a) -> STM m ()
modifyTVar v f = readTVar v >>= writeTVar v . f

-- | Update the contents of a 'StrictTVar' and return a derived result.
--
-- The function receives the current value and yields the new value together
-- with a result @b@. The new value is stored with 'writeTVar' (and therefore
-- forced to WHNF and checked against the invariant); the result @b@ is
-- returned as is, without being forced here.
stateTVar :: MonadSTM m => StrictTVar m a -> (a -> (a, b)) -> STM m b
stateTVar v f = do
    a <- readTVar v
    let (a', b) = f a
    writeTVar v a'
    return b

-- | Deprecated synonym for 'stateTVar'.
updateTVar :: MonadSTM m => StrictTVar m a -> (a -> (a, b)) -> STM m b
updateTVar = stateTVar
{-# DEPRECATED updateTVar "Use stateTVar" #-}

{-------------------------------------------------------------------------------
  Strict TMVar
-------------------------------------------------------------------------------}

-- | 'TMVar' that keeps its value in WHNF at all times
--
-- Does not support an invariant: if the invariant would not be satisfied,
-- we would not be able to put a value into an empty TMVar, which would lead
-- to very hard to debug bugs where code is blocked indefinitely.
newtype StrictTMVar m a = StrictTMVar (LazyTMVar m a)

-- | Re-tag a 'StrictTMVar' from monad @m@ to monad @n@.
--
-- Only possible when both monads use the same underlying 'LazyTMVar' type,
-- as required by the equality constraint; the wrapped variable is unchanged.
castStrictTMVar :: LazyTMVar m ~ LazyTMVar n
                => StrictTMVar m a -> StrictTMVar n a
castStrictTMVar (StrictTMVar var) = StrictTMVar var

-- | Create a full 'StrictTMVar' inside an STM transaction.
--
-- The initial value is forced to WHNF (bang pattern) before it is stored.
newTMVar :: MonadSTM m => a -> STM m (StrictTMVar m a)
newTMVar !a = StrictTMVar <$> Lazy.newTMVar a

-- | Create a full 'StrictTMVar' directly in @m@.
--
-- The initial value is forced to WHNF (bang pattern) before it is stored.
newTMVarIO :: MonadSTM m => a -> m (StrictTMVar m a)
newTMVarIO !a = StrictTMVar <$> Lazy.newTMVarIO a

-- | Deprecated synonym for 'newTMVarIO'.
newTMVarM :: MonadSTM m => a -> m (StrictTMVar m a)
newTMVarM = newTMVarIO
-- The replacement is 'newTMVarIO' (the definition above), not 'newTVarIO':
-- the latter creates a 'StrictTVar', not a 'StrictTMVar'.
{-# DEPRECATED newTMVarM "Use newTMVarIO" #-}

-- | Create an empty 'StrictTMVar' inside an STM transaction.
newEmptyTMVar :: MonadSTM m => STM m (StrictTMVar m a)
newEmptyTMVar = StrictTMVar <$> Lazy.newEmptyTMVar

-- | Create an empty 'StrictTMVar' directly in @m@.
newEmptyTMVarIO :: MonadSTM m => m (StrictTMVar m a)
newEmptyTMVarIO = StrictTMVar <$> Lazy.newEmptyTMVarIO

-- | Deprecated synonym for 'newEmptyTMVarIO'.
newEmptyTMVarM :: MonadSTM m => m (StrictTMVar m a)
newEmptyTMVarM = newEmptyTMVarIO
{-# DEPRECATED newEmptyTMVarM "Use newEmptyTMVarIO" #-}

-- | Take the value out of a 'StrictTMVar'.
--
-- Unwraps the newtype and delegates to the lazy @takeTMVar@.
takeTMVar :: MonadSTM m => StrictTMVar m a -> STM m a
takeTMVar (StrictTMVar tmvar) = Lazy.takeTMVar tmvar

-- | Attempt to take the value out of a 'StrictTMVar', yielding a 'Maybe'.
--
-- Unwraps the newtype and delegates to the lazy @tryTakeTMVar@.
tryTakeTMVar :: MonadSTM m => StrictTMVar m a -> STM m (Maybe a)
tryTakeTMVar (StrictTMVar tmvar) = Lazy.tryTakeTMVar tmvar

-- | Put a value into a 'StrictTMVar'.
--
-- The value is forced to WHNF (bang pattern) before being handed to the lazy
-- @putTMVar@.
putTMVar :: MonadSTM m => StrictTMVar m a -> a -> STM m ()
putTMVar (StrictTMVar tmvar) !a = Lazy.putTMVar tmvar a

-- | Attempt to put a value into a 'StrictTMVar', reporting the outcome as a
-- 'Bool'.
--
-- The value is forced to WHNF (bang pattern) before being handed to the lazy
-- @tryPutTMVar@, regardless of whether the put succeeds.
tryPutTMVar :: MonadSTM m => StrictTMVar m a -> a -> STM m Bool
tryPutTMVar (StrictTMVar tmvar) !a = Lazy.tryPutTMVar tmvar a

-- | Read the value of a 'StrictTMVar'.
--
-- Unwraps the newtype and delegates to the lazy @readTMVar@.
readTMVar :: MonadSTM m => StrictTMVar m a -> STM m a
readTMVar (StrictTMVar tmvar) = Lazy.readTMVar tmvar

-- | Attempt to read the value of a 'StrictTMVar', yielding a 'Maybe'.
--
-- Unwraps the newtype and delegates to the lazy @tryReadTMVar@.
tryReadTMVar :: MonadSTM m => StrictTMVar m a -> STM m (Maybe a)
tryReadTMVar (StrictTMVar tmvar) = Lazy.tryReadTMVar tmvar

-- | Swap the contents of a 'StrictTMVar' for a new value.
--
-- The new value is forced to WHNF (bang pattern) before being handed to the
-- lazy @swapTMVar@, whose result is returned.
swapTMVar :: MonadSTM m => StrictTMVar m a -> a -> STM m a
swapTMVar (StrictTMVar tmvar) !a = Lazy.swapTMVar tmvar a

-- | Check whether a 'StrictTMVar' is currently empty.
--
-- Unwraps the newtype and delegates to the lazy @isEmptyTMVar@.
isEmptyTMVar :: MonadSTM m => StrictTMVar m a -> STM m Bool
isEmptyTMVar (StrictTMVar tmvar) = Lazy.isEmptyTMVar tmvar

{-------------------------------------------------------------------------------
  Dealing with invariants
-------------------------------------------------------------------------------}

-- | Check invariant (if enabled) before continuing
--
-- @checkInvariant mErr x@ is equal to @x@ if @mErr == Nothing@, and throws
-- an error @err@ if @mErr == Just err@.
--
-- Checking is only compiled in when the CPP macro @CHECK_TVAR_INVARIANT@ is
-- set; otherwise the first argument is ignored and @x@ is returned unchanged.
--
-- This is exported so that other code that wants to conditionally check
-- invariants can reuse the same logic, rather than having to introduce new
-- per-package flags.
checkInvariant :: HasCallStack => Maybe String -> a -> a
#if CHECK_TVAR_INVARIANT
checkInvariant Nothing    k = k
checkInvariant (Just err) _ = error $ "Invariant violation: " ++ err
#else
-- Checks disabled: never inspect the (possibly expensive) invariant result.
checkInvariant _err       k = k
#endif