{-# LANGUAGE RankNTypes #-}

-- |
-- Copyright: © 2018-2020 IOHK
-- License: Apache-2.0
--
-- Provides functions for setting up and capturing logging so that expectations
-- about logging can be asserted in test scenarios.

module Test.Utils.Trace
     ( withLogging
     , captureLogging
     , countMsg
     ) where

import Prelude

import Cardano.BM.Trace
    ( traceInTVarIO )
import Control.Concurrent.STM.TVar
    ( newTVarIO, readTVarIO )
import Control.Tracer
    ( Tracer )
import Control.Lens.Prism
    ( Prism' )
import Control.Lens.Operators
    ( (^?) )

import Data.Maybe
    ( isJust )

-- | Run an action with a logging 'Tracer', and a function to get all
-- messages that have been traced so far.
--
-- The action receives a pair of:
--
-- * a 'Tracer' that records every message it is given in a fresh @TVar@;
--
-- * an @IO [msg]@ that reads back the messages recorded so far.
--
-- The recorded list is reversed before being returned. This presumes that
-- 'traceInTVarIO' prepends each new message, so the result lists messages in
-- the order they were traced (TODO confirm against cardano-bm).
withLogging :: ((Tracer IO msg, IO [msg]) -> IO a) -> IO a
withLogging action = do
    tvar <- newTVarIO []
    let getMsgs = reverse <$> readTVarIO tvar
    action (traceInTVarIO tvar, getMsgs)

-- | Run an action with a 'Tracer', returning the log messages captured while
-- it ran along with the result of the action.
--
-- The messages are read once, after the action has returned. If the action
-- throws, the exception propagates and no messages are returned.
captureLogging :: (Tracer IO msg -> IO a) -> IO ([msg], a)
captureLogging action = withLogging $ \(tr, getMsgs) -> do
    res <- action tr
    msgs <- getMsgs
    pure (msgs, res)

-- | Count elements in the list matching the given 'Prism''. Handy for counting
-- log messages, which are typically constructed as sum types with many
-- constructors.
--
-- A 'Prism'' may look scary, but one is easy to obtain if the target type
-- derives 'Generic': use @_Ctor@ from "Data.Generics.Sum.Constructors"
-- (package @generic-lens@).
--
-- __Example:__
--
-- >>> data MySumType = MyConstructor | MyOtherConstructor deriving Generic
--
-- >>> xs = [ MyConstructor, MyOtherConstructor, MyConstructor ]
--
-- >>> countMsg (_Ctor @"MyConstructor") xs
-- 2
countMsg :: Prism' s a -> [s] -> Int
countMsg prism = length . filter matches
  where
    -- An element matches when the prism can focus on it (i.e. it was built
    -- with the constructor the prism targets).
    matches x = isJust (x ^? prism)