{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Traces of transition systems and associated operators.
--
-- This module also includes a minimal domain-specific language (DSL) for
-- specifying expectations on traces.
--
module Control.State.Transition.Trace
  ( -- * Trace checking
    (.-)
  , (.->)
  , checkTrace
    -- * Trace
  , Trace (..)
  , TraceOrder (NewestFirst, OldestFirst)
  , mkTrace
  , traceEnv
  , traceInitState
  , traceSignals
  , traceStates
  , preStatesAndSignals
  , SourceSignalTarget (..)
  , sourceSignalTargets
  , traceLength
  , traceInit
  , lastState
  , lastSignal
  , firstAndLastState
  , closure
  -- * Miscellaneous utilities
  , extractValues
  , applySTSTest
  )
where

import           Control.Monad (void)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Monad.Reader (MonadReader, ReaderT, ask, runReaderT)
import           Data.Data (Data, Typeable, cast, gmapQ)
import           Data.Functor ((<&>))
import           Data.Maybe (catMaybes)
import           Data.Sequence.Strict (StrictSeq ((:<|), Empty))
import qualified Data.Sequence.Strict as StrictSeq
import           GHC.Generics (Generic)
import           GHC.Stack (HasCallStack)
import           Lens.Micro (Lens', lens, to, (^.), (^..))
import           Lens.Micro.TH (makeLenses)
import           NoThunks.Class (NoThunks(..))
import           Test.Tasty.HUnit (assertFailure, (@?=))

import           Control.State.Transition.Extended hiding (Assertion, trans)

-- A post-transition state paired with the signal that produced it.
--
-- The state is stored first and the signal second. Both fields carry strictness
-- annotations, unlike a plain tuple, so building a 'SigState' forces both
-- components to weak head normal form.
data SigState s
  = SigState
      !(State s)
      !(Signal s)
  deriving (Generic)

-- Lens onto the state component of a 'SigState'.
transSt :: Lens' (SigState s) (State s)
transSt = lens getState setState
  where
    -- Project out the stored state.
    getState (SigState st _) = st
    -- Replace the stored state, leaving the signal untouched.
    setState (SigState _ sig) st' = SigState st' sig

-- Lens onto the signal component of a 'SigState'.
transSig :: Lens' (SigState s) (Signal s)
transSig = lens getSignal setSignal
  where
    -- Project out the stored signal.
    getSignal (SigState _ sig) = sig
    -- Replace the stored signal, leaving the state untouched.
    setSignal (SigState st _) sig' = SigState st sig'

-- Structural equality and rendering, available only when both the state and
-- the signal types support them.
deriving instance (Eq (State s), Eq (Signal s)) => Eq (SigState s)

deriving instance (Show (State s), Show (Signal s)) => Show (SigState s)

-- No methods are given, so the class defaults (which rely on the 'Generic'
-- instance of 'SigState') are used.
instance (NoThunks (State s), NoThunks (Signal s)) => NoThunks (SigState s)

-- | A successful trace of a transition system.
--
data Trace s
  = Trace
    { _traceEnv :: !(Environment s)
      -- ^ Environment under which the trace was run.
    , _traceInitState :: !(State s)
      -- ^ Initial state in the trace.
    , _traceTrans :: !(StrictSeq (SigState s))
      -- ^ Signals and resulting states observed in the trace. New elements are
      -- put at the front of the sequence, so the newest transition comes first.
    }
  deriving (Generic)

-- Generate the lenses 'traceEnv', 'traceInitState' and 'traceTrans' from the
-- underscore-prefixed record fields of 'Trace'.
makeLenses ''Trace

-- Structural equality and rendering, available only when the environment,
-- state and signal types all support them.
deriving instance
  (Eq (State s), Eq (Signal s), Eq (Environment s)) => Eq (Trace s)

deriving instance
  (Show (State s), Show (Signal s), Show (Environment s)) => Show (Trace s)

-- No methods are given, so the class defaults (which rely on the 'Generic'
-- instance of 'Trace') are used.
instance
  (NoThunks (Environment s), NoThunks (State s), NoThunks (Signal s))
  => NoThunks (Trace s)

-- | Make a trace given an environment, an initial state, and a list of
-- (resulting state, signal) pairs ordered newest first.
--
-- The pairs are stored exactly as given: no transition rule is run to check
-- that the states actually follow from the signals (use 'closure' for that).
mkTrace :: Environment s -> State s -> [(State s, Signal s)] -> Trace s
mkTrace env initState sigs =
  Trace
    { _traceEnv = env
    , _traceInitState = initState
    , _traceTrans = StrictSeq.fromList [SigState st sig | (st, sig) <- sigs]
    }

-- $setup
-- |
-- >>> :set -XTypeFamilies
-- >>> import Control.State.Transition (initialRules, transitionRules)
-- >>> :{
-- data DUMMY
-- data DummyPredicateFailure = CeciNEstPasUnePredicateFailure deriving (Eq, Show)
-- instance STS DUMMY where
--   type Environment DUMMY = Bool
--   type State DUMMY = Int
--   type Signal DUMMY = String
--   type PredicateFailure DUMMY = DummyPredicateFailure
--   initialRules = []
--   transitionRules = []
-- :}

-- | Extract the last state of a trace. Since a trace has at least an initial
-- state, the last state of a trace is always defined.
--
-- Examples:
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> lastState tr0
-- 0
--
-- >>> tr01 = mkTrace True 0 [(1, "one")] :: Trace DUMMY
-- >>> lastState tr01
-- 1
--
-- >>> tr012 = mkTrace True 0 [(2, "two"), (1, "one")] :: Trace DUMMY
-- >>> lastState tr012
-- 2
--
lastState :: Trace s -> State s
lastState tr =
  case _traceTrans tr of
    -- The newest transition sits at the front of the sequence.
    SigState newest _ :<| _ -> newest
    -- No transitions: the trace consists only of its initial state.
    _ -> _traceInitState tr


-- | Get the last applied signal in a trace (that is, the newest signal).
--
-- Precondition: the trace must contain at least one signal; otherwise this
-- calls 'error'.
--
-- Examples:
--
-- >>> :set -XScopedTypeVariables
-- >>> import Control.Exception (catch, ErrorCall)
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> print (lastSignal tr0) `catch` (\(_ :: ErrorCall) -> putStrLn "error!")
-- "error!
--
-- The missing closing double quote above is not a doctest quirk. The signal
-- type of DUMMY is 'String', so 'print' starts writing the output of 'show',
-- which begins with an opening double quote, before the signal itself is
-- forced. The error is raised only after that character has been written, and
-- the handler then prints error! on the same line.
--
-- >>> tr01 = mkTrace True 0 [(1, "one")] :: Trace DUMMY
-- >>> lastSignal tr01
-- "one"
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> lastSignal tr0123
-- "three"
--
lastSignal :: HasCallStack => Trace s -> Signal s
lastSignal Trace {_traceTrans = transitions} =
  case transitions of
    -- The newest transition sits at the front of the sequence.
    SigState _ newestSig :<| _ -> newestSig
    Empty -> error "lastSignal was called with a trace without signals"


-- | Return the first and last state of the trace.
--
-- The first state is returned in the first component of the result tuple.
--
-- Examples:
--
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> firstAndLastState tr0
-- (0,0)
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> firstAndLastState tr0123
-- (0,3)
--
firstAndLastState :: Trace s -> (State s, State s)
firstAndLastState tr = (tr ^. traceInitState, lastState tr)


-- | Order in which the elements of a trace are listed.
data TraceOrder
  = NewestFirst
    -- ^ Most recent element first.
  | OldestFirst
    -- ^ Earliest element first.
  deriving (Eq)

-- Rearrange a list that is already newest-first into the requested order.
fromNewestFirst :: TraceOrder -> [a] -> [a]
fromNewestFirst = \case
  NewestFirst -> id
  OldestFirst -> reverse

-- | Retrieve all the signals in the trace, in the order specified.
--
-- Examples:
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> traceSignals NewestFirst tr0
-- []
--
-- >>> tr01 = mkTrace True 0 [(1, "one")] :: Trace DUMMY
-- >>> traceSignals NewestFirst tr01
-- ["one"]
--
-- >>> traceSignals OldestFirst tr01
-- ["one"]
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> traceSignals NewestFirst tr0123
-- ["three","two","one"]
--
-- >>> traceSignals OldestFirst tr0123
-- ["one","two","three"]
--
traceSignals :: TraceOrder -> Trace s -> [Signal s]
traceSignals order tr = fromNewestFirst order signalsNewestFirst
  where
    -- Visit every stored transition (kept newest-first) and project out its
    -- signal.
    signalsNewestFirst = tr ^.. traceTrans . traverse . transSig

-- | Retrieve all the states in the trace, in the order specified.
--
-- Examples:
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> traceStates NewestFirst tr0
-- [0]
--
-- >>> traceStates OldestFirst tr0
-- [0]
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> traceStates NewestFirst tr0123
-- [3,2,1,0]
--
-- >>> traceStates OldestFirst tr0123
-- [0,1,2,3]
--
traceStates :: TraceOrder -> Trace s -> [State s]
traceStates order tr = fromNewestFirst order statesNewestFirst
  where
    -- Post-transition states in stored (newest-first) order, followed by the
    -- initial state, which is the oldest state of all.
    statesNewestFirst =
      foldr
        (\(SigState st _) older -> st : older)
        [_traceInitState tr]
        (_traceTrans tr)

-- | Compute the length of a trace, defined as the number of signals it
-- contains.
--
-- Examples:
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> traceLength tr0
-- 0
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> traceLength tr0123
-- 3
--
traceLength :: Trace s -> Int
traceLength = (^. traceTrans . to length)

-- | Take all but the newest signal in the trace.
--
-- Precondition: the trace must contain at least one signal
--
-- Examples:
--
--
-- >>> :set -XScopedTypeVariables
-- >>> import Control.Exception (catch, ErrorCall)
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> print (traceInit tr0) `catch` (\(_ :: ErrorCall) -> print "error!")
-- "error!"
--
-- >>> tr01 = mkTrace True 0 [(1, "one")] :: Trace DUMMY
-- >>> traceInit tr01
-- Trace {_traceEnv = True, _traceInitState = 0, _traceTrans = StrictSeq {fromStrict = fromList []}}
--
-- >>> tr012 = mkTrace True 0 [(2, "two"), (1, "one")] :: Trace DUMMY
-- >>> traceInit tr012
-- Trace {_traceEnv = True, _traceInitState = 0, _traceTrans = StrictSeq {fromStrict = fromList [SigState 1 "one"]}}
--
traceInit :: HasCallStack => Trace s -> Trace s
traceInit tr =
  case _traceTrans tr of
    -- Drop the newest transition (the front of the sequence); environment and
    -- initial state are kept as they are.
    _newest :<| older -> tr {_traceTrans = older}
    Empty -> error "traceInit was called with a trace without signals"

-- | Retrieve all the signals in the trace paired with the state prior to the
-- application of the signal.
--
-- Note that the last state in the trace will not be returned, since there is
-- no corresponding signal, i.e. the last state is not the pre-state of any
-- signal in the trace.
--
-- Examples
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> preStatesAndSignals NewestFirst tr0
-- []
--
-- >>> preStatesAndSignals OldestFirst tr0
-- []
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> preStatesAndSignals OldestFirst tr0123
-- [(0,"one"),(1,"two"),(2,"three")]
--
-- >>> preStatesAndSignals NewestFirst tr0123
-- [(2,"three"),(1,"two"),(0,"one")]
--
preStatesAndSignals :: TraceOrder -> Trace s -> [(State s, Signal s)]
preStatesAndSignals order tr =
  case order of
    OldestFirst -> oldestFirst
    NewestFirst -> reverse oldestFirst
  where
    -- There is one more state than there are signals, so zipping the
    -- oldest-first lists naturally leaves out the final state.
    oldestFirst = zip (traceStates OldestFirst tr) (traceSignals OldestFirst tr)

-- | Apply the signals in the list and elaborate a trace with the resulting
-- states.
--
-- If any of the signals cannot be applied, then it is discarded, and the next
-- signal is tried.
--
-- >>> :set -XTypeFamilies
-- >>> :set -XTypeApplications
-- >>> import Control.State.Transition (initialRules, transitionRules, judgmentContext)
-- >>> import Data.Functor.Identity
-- >>> :{
-- data ADDER
-- data AdderPredicateFailure = NoFailuresPossible deriving (Eq, Show)
-- instance STS ADDER where
--   type Environment ADDER = ()
--   type State ADDER = Int
--   type Signal ADDER = Int
--   type PredicateFailure ADDER = AdderPredicateFailure
--   initialRules = [ pure 0 ]
--   transitionRules =
--     [ do
--         TRC ((), st, inc) <- judgmentContext
--         pure $! st + inc
--     ]
-- :}
--
-- >>> runIdentity $ closure @ADDER () 0 [3, 2, 1]
-- Trace {_traceEnv = (), _traceInitState = 0, _traceTrans = StrictSeq {fromStrict = fromList [SigState 6 3,SigState 3 2,SigState 1 1]}}
--
-- >>> runIdentity $ closure @ADDER () 10 [-3, -2, -1]
-- Trace {_traceEnv = (), _traceInitState = 10, _traceTrans = StrictSeq {fromStrict = fromList [SigState 4 (-3),SigState 7 (-2),SigState 9 (-1)]}}
--
closure
  :: forall s m
   . (STS s, m ~ BaseM s)
  => Environment s
  -> State s
  -> [Signal s]
  -- ^ List of signals to apply, where the newest signal comes first.
  -> m (Trace s)
closure env st0 sigs = mkTrace env st0 <$> go st0 (reverse sigs) []
  where
    -- Walk the signals oldest-first, threading the current state, and
    -- accumulate the accepted (state, signal) pairs newest-first, which is the
    -- order 'mkTrace' expects.
    go :: State s -> [Signal s] -> [(State s, Signal s)] -> m [(State s, Signal s)]
    go _ [] applied = pure applied
    go current (sig : rest) applied = do
      outcome <- applySTSTest @s (TRC (env, current, sig))
      case outcome of
        -- Rejected signal: skip it and keep the current state.
        Left _failures -> go current rest applied
        Right next -> go next rest ((next, sig) : applied)

--------------------------------------------------------------------------------
-- Minimal DSL to specify expectations on traces
--------------------------------------------------------------------------------

-- | Bind the state inside the first argument, and apply the transition
-- function in the @Reader@ environment to that state and given signal,
-- obtaining the resulting state, or an assertion failure if the transition
-- function fails.
(.-)
  :: forall m st sig err
   . ( MonadIO m
     , MonadReader (st -> sig -> Either err st) m
     , Show err
     )
  => m st -> sig -> m st
mSt .- sig = do
  st <- mSt
  -- The reader environment holds the transition function to apply.
  step <- ask
  -- A rejected transition fails the test with the shown error; an accepted
  -- one yields the new state.
  either (liftIO . assertFailure . show) pure (step st sig)

-- | Bind the state inside the first argument, and check whether it is equal to
-- the expected state, given in the second argument.
--
-- The actual state is returned, so further steps can be chained after it.
(.->)
  :: forall m st
   . (MonadIO m, Eq st, Show st)
  => m st -> st -> m st
mSt .-> stExpected =
  mSt >>= \stActual -> stActual <$ liftIO (stActual @?= stExpected)

-- | Run a trace-checking action (typically built from '.-' and '.->') and
-- discard the final state it produces.
--
-- The reader environment given to the action is the transition function of
-- @s@ under @env@. For a state and a signal, it runs the rules through
-- 'applySTSTest' and uses @interp@ to extract the result from the base
-- monad @m@.
checkTrace
  :: forall s m
   . (STS s, BaseM s ~ m)
  => (forall a. m a -> a)
  -> Environment s
  -> ReaderT (State s -> Signal s -> (Either [[PredicateFailure s]] (State s))) IO (State s)
  -> IO ()
checkTrace interp env act = void (runReaderT act step)
  where
    step :: State s -> Signal s -> Either [[PredicateFailure s]] (State s)
    step st sig = interp (applySTSTest @s (TRC (env, st, sig)))

-- | Extract all the values of a given type.
--
-- Examples:
--
-- >>> extractValues "hello" :: [Char]
-- "hello"
--
-- >>> extractValues ("hello", " " ,"world") :: [Char]
-- "hello world"
--
-- >>> extractValues "hello" :: [Int]
-- []
--
-- >>> extractValues ([('a', 0 :: Int), ('b', 1)] :: [(Char, Int)]) :: [Int]
-- [0,1]
--
-- >>> extractValues (["hello"] :: [[Char]], 1, 'z') :: [[Char]]
-- ["hello","ello","llo","lo","o",""]
--
-- >>> extractValues ("hello", 'z') :: [Char]
-- "zhello"
--
extractValues :: forall d a . (Data d, Typeable a) => d -> [a]
extractValues d = immediate ++ nested
  where
    -- Direct children of @d@ whose type is exactly @a@. These come first in
    -- the result.
    immediate = catMaybes (gmapQ (\child -> cast child) d)
    -- Matches found anywhere below each direct child, searched recursively.
    nested = concat (gmapQ extractValues d)

-- | A single step of a trace: the state before the step, the state after it,
-- and the signal that connects them.
data SourceSignalTarget a =
  SourceSignalTarget {
    source :: State a
    -- ^ State the signal was applied to.
  , target :: State a
    -- ^ State obtained after applying the signal.
  , signal :: Signal a
    -- ^ Signal applied to 'source' to obtain 'target'.
  }

deriving instance (Eq (State a), Eq (Signal a)) => Eq (SourceSignalTarget a)
deriving instance (Show (State a), Show (Signal a)) => Show (SourceSignalTarget a)

-- | Extract triplets of the form
-- @SourceSignalTarget {source = s, target = t, signal = sig}@ from a trace,
-- oldest step first. For a valid trace, each source state can reach its
-- target state via the given signal.
--
-- Examples
--
--
-- >>> tr0 = mkTrace True 0 [] :: Trace DUMMY
-- >>> sourceSignalTargets tr0
-- []
--
-- >>> tr0123 = mkTrace True 0 [(3, "three"), (2, "two"), (1, "one")] :: Trace DUMMY
-- >>> sourceSignalTargets tr0123
-- [SourceSignalTarget {source = 0, target = 1, signal = "one"},SourceSignalTarget {source = 1, target = 2, signal = "two"},SourceSignalTarget {source = 2, target = 3, signal = "three"}]
--
sourceSignalTargets :: forall a. Trace a -> [SourceSignalTarget a]
sourceSignalTargets tr =
  -- Pair every state with its successor. 'drop 1' is the total counterpart of
  -- 'tail': 'traceStates' always includes the initial state, so the result is
  -- the same, but no partial function is involved.
  zipWith3 SourceSignalTarget states (drop 1 states) signals
  where
    -- Oldest-first, so each state is immediately followed by the state its
    -- signal produced.
    states = traceStates OldestFirst tr
    signals = traceSignals OldestFirst tr

-- | Apply STS checking assertions.
--
-- Runs the rule with all assertions and all validations enabled. Returns
-- @Right@ the resulting state when no predicate failures were reported, and
-- @Left@ the reported predicate failures otherwise.
applySTSTest ::
  forall s m rtype.
  (STS s, RuleTypeRep rtype, m ~ BaseM s) =>
  RuleContext rtype s ->
  m (Either [[PredicateFailure s]] (State s))
applySTSTest ctx = applySTSOpts checkEverything ctx <&> toEither
  where
    checkEverything :: ApplySTSOpts
    checkEverything =
      ApplySTSOpts
        { asoAssertions = AssertionsAll
        , asoValidation = ValidateAll
        }
    -- Any reported failure makes the result a Left and discards the state.
    toEither (st, []) = Right st
    toEither (_, failures) = Left failures