{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Counter (
    Counter,
    newCounter,
    waitForZero,
    increase,
    decrease,
    waitForDecreased,
) where

import Control.Concurrent.STM

import Network.Wai.Handler.Warp.Imports

-- | An integer counter stored in a 'TVar', so it can be read and updated
-- inside STM transactions.
newtype Counter = Counter (TVar Int)

-- | Create a new 'Counter' whose initial value is 0.
newCounter :: IO Counter
newCounter = Counter <$> newTVarIO 0

-- | Block the calling thread while the counter is greater than 0.
--
-- The transaction retries as long as the value read is positive, and
-- returns as soon as the value is 0 or less.
waitForZero :: Counter -> IO ()
waitForZero (Counter var) = atomically $ do
    x <- readTVar var
    -- 'retry' suspends the transaction until 'var' is written again.
    when (x > 0) retry

-- | Block the calling thread until the counter is strictly smaller than
-- the value it had when this function was called.
--
-- The snapshot and the wait run in two separate transactions, so a
-- decrease that happens between them is still observed by the second
-- transaction and lets it return immediately.
waitForDecreased :: Counter -> IO ()
waitForDecreased (Counter var) = do
    -- Snapshot the current value to compare against.
    n0 <- atomically $ readTVar var
    atomically $ do
        n <- readTVar var
        -- 'check' retries the transaction until the condition holds.
        check (n < n0)

-- | Atomically add 1 to the counter.
--
-- Uses the strict 'modifyTVar'' so the new value is evaluated when it is
-- stored rather than left as a thunk in the 'TVar'.
increase :: Counter -> IO ()
increase (Counter var) = atomically $ modifyTVar' var (+ 1)

-- | Atomically subtract 1 from the counter.
--
-- Uses the strict 'modifyTVar'' so the new value is evaluated when it is
-- stored rather than left as a thunk in the 'TVar'.
decrease :: Counter -> IO ()
decrease (Counter var) = atomically $ modifyTVar' var (subtract 1)