{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables   #-}

module Ouroboros.Network.Mux
  ( MuxMode (..)
  , OuroborosApplication (..)
  , MiniProtocol (..)
  , MiniProtocolNum (..)
  , MiniProtocolLimits (..)
  , RunMiniProtocol (..)
  , MuxPeer (..)
  , toApplication
  , ControlMessage (..)
  , ControlMessageSTM
  , continueForever
  , timeoutWithControlMessage

    -- * Re-exports
    -- | from "Network.Mux"
  , MuxError(..)
  , MuxErrorType(..)
  , HasInitiator
  , HasResponder
  ) where

import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow
import           Control.Tracer (Tracer)

import           Data.Void (Void)
import qualified Data.ByteString.Lazy as LBS

import           Network.TypedProtocol.Core
import           Network.TypedProtocol.Pipelined

import qualified Network.Mux.Compat as Mux
import           Network.Mux
                   ( MuxMode(..), HasInitiator, HasResponder
                   , MiniProtocolNum, MiniProtocolLimits(..)
                   , MuxError(..), MuxErrorType(..) )

import           Ouroboros.Network.Channel
import           Ouroboros.Network.ConnectionId
import           Ouroboros.Network.Codec
import           Ouroboros.Network.Driver
import           Ouroboros.Network.Util.ShowProxy (ShowProxy)


-- | Control signal sent to a mini-protocol.  On 'Terminate' it is expected to
-- exit, on 'Continue' it should continue its operation.
--
data ControlMessage =
    -- | Continue operation.
      Continue

    -- | Hold on, e.g. do not send messages until resumed.  This is not used for
    -- any hot protocol.
    --
    | Quiesce

    -- | The client is expected to terminate as soon as possible.
    --
    | Terminate
  deriving (Eq, Show)

-- | An 'STM' action which reads the current 'ControlMessage'.
--
-- 'ControlMessageSTM' should depend on @muxMode@ (we only need to schedule a
-- stop for the initiator side).  This is not done only because it would break
-- tests, but once the old API is removed it should be possible.
--
type ControlMessageSTM m = STM m ControlMessage

-- | A 'ControlMessageSTM' which always returns 'Continue', so it never asks a
-- mini-protocol to quiesce or terminate.
--
-- The proxy argument is ignored.  It only fixes @m@: 'STM' is a non-injective
-- type family, so @m@ cannot be inferred from @'STM' m 'ControlMessage'@ alone.
--
continueForever :: Applicative (STM m)
                => proxy m
                -> ControlMessageSTM m
continueForever _ = pure Continue


-- | First-to-finish synchronisation between the 'Terminate' state of
-- 'ControlMessage' and an stm action.
--
-- * Returns 'Nothing' as soon as the control message is 'Terminate'.
--   The control message is checked first (left branch of 'orElse'), so
--   'Terminate' wins even if @stm@ could also complete.
-- * Returns @'Just' a@ once @stm@ completes.
-- * Blocks (retries) while the control message is 'Continue' or 'Quiesce'
--   and @stm@ itself retries.
--
-- This should return @STM m (Maybe a)@ but 'STM' is a non-injective type
-- family, and we would need to pass @Proxy m@ to fix an ambiguous type (or use
-- the 'AllowAmbiguousTypes' extension).
--
timeoutWithControlMessage :: MonadSTM m
                          => ControlMessageSTM m
                          -> STM m a
                          -> m (Maybe a)
timeoutWithControlMessage controlMessageSTM stm =
    atomically $
      do
        cntrlMsg <- controlMessageSTM
        case cntrlMsg of
          Terminate -> return Nothing
          -- Not terminating: block this branch so 'orElse' falls through
          -- to @stm@.
          Continue  -> retry
          Quiesce   -> retry
      `orElse` (Just <$> stm)


-- | An application made of typed mini-protocols.  It plays the same role as
-- 'Mux.MuxApplication', but each mini-protocol is described by a 'MuxPeer'
-- instead of a raw @Channel -> m a@ action.  The list of mini-protocols is
-- computed from the connection's 'ConnectionId' and its 'ControlMessageSTM'.
--
newtype OuroborosApplication (mode :: MuxMode) addr bytes m a b
  = OuroborosApplication
      (   ConnectionId addr
       -> ControlMessageSTM m
       -> [MiniProtocol mode bytes m a b]
      )

-- | A single mini-protocol of an 'OuroborosApplication': its number, its
-- limits and how to run it.  All fields are strict.
--
data MiniProtocol (mode :: MuxMode) bytes m a b =
     MiniProtocol {
       -- | The mini-protocol number; 'toApplication' copies it into
       -- 'Mux.miniProtocolNum'.
       miniProtocolNum    :: !MiniProtocolNum,
       -- | The mini-protocol limits; 'toApplication' copies them into
       -- 'Mux.miniProtocolLimits'.
       miniProtocolLimits :: !MiniProtocolLimits,
       -- | The peer(s) to run, indexed by the connection's 'MuxMode'.
       miniProtocolRun    :: !(RunMiniProtocol mode bytes m a b)
     }

-- | How a mini-protocol is run, indexed by the 'MuxMode' of the connection.
--
-- Each constructor ties the mode to the peers that must be supplied:
--
-- * an initiator-only protocol fixes the responder result to 'Void';
-- * a responder-only protocol fixes the initiator result to 'Void';
-- * 'InitiatorResponderMode' requires both peers.
--
data RunMiniProtocol (mode :: MuxMode) bytes m a b where
     -- | Only the initiator side runs, producing @a@.
     InitiatorProtocolOnly
       :: MuxPeer bytes m a
       -> RunMiniProtocol InitiatorMode bytes m a Void

     -- | Only the responder side runs, producing @b@.
     ResponderProtocolOnly
       :: MuxPeer bytes m b
       -> RunMiniProtocol ResponderMode bytes m Void b

     -- | Both sides run.  The first peer is the initiator (result @a@) and the
     -- second is the responder (result @b@).
     InitiatorAndResponderProtocol
       :: MuxPeer bytes m a
       -> MuxPeer bytes m b
       -> RunMiniProtocol InitiatorResponderMode bytes m a b

-- | A mini-protocol peer together with what is needed to run it over a
-- 'Channel'; 'runMuxPeer' interprets each constructor.
--
-- The 'Show' \/ 'ShowProxy' constraints (including the quantified ones over
-- all protocol states) are those required by 'runPeer' and
-- 'runPipelinedPeer'.  Capturing them in the constructors lets 'runMuxPeer'
-- call the drivers without further constraints on the protocol type @ps@.
--
data MuxPeer bytes m a where
    -- | A non-pipelined typed-protocol 'Peer', run with 'runPeer' using the
    -- given tracer and codec.
    MuxPeer :: forall (pr :: PeerRole) ps (st :: ps) failure bytes m a.
               ( Show failure
               , forall (st' :: ps). Show (ClientHasAgency st')
               , forall (st' :: ps). Show (ServerHasAgency st')
               , ShowProxy ps
               )
            => Tracer m (TraceSendRecv ps)
            -> Codec ps failure m bytes
            -> Peer ps pr st m a
            -> MuxPeer bytes m a

    -- | A pipelined typed-protocol peer ('PeerPipelined'), run with
    -- 'runPipelinedPeer' using the given tracer and codec.
    MuxPeerPipelined
             :: forall (pr :: PeerRole) ps (st :: ps) failure bytes m a.
               ( Show failure
               , forall (st' :: ps). Show (ClientHasAgency st')
               , forall (st' :: ps). Show (ServerHasAgency st')
               , ShowProxy ps
               )
            => Tracer m (TraceSendRecv ps)
            -> Codec ps failure m bytes
            -> PeerPipelined ps pr st m a
            -> MuxPeer bytes m a

    -- | A raw action on the 'Channel'.  It is applied to the channel as-is,
    -- with no codec or tracer involved.
    MuxPeerRaw
           :: (Channel m bytes -> m (a, Maybe bytes))
           -> MuxPeer bytes m a

-- | Convert an 'OuroborosApplication' into a 'Mux.MuxApplication'.
--
-- The application is applied to the given 'ConnectionId' and
-- 'ControlMessageSTM' to get its mini-protocols.  Each one keeps its number
-- and limits, and its 'RunMiniProtocol' is translated with
-- 'toMuxRunMiniProtocol'.
--
toApplication :: (MonadCatch m, MonadAsync m)
              => ConnectionId addr
              -> ControlMessageSTM m
              -> OuroborosApplication mode addr LBS.ByteString m a b
              -> Mux.MuxApplication mode m a b
toApplication connectionId controlMessageSTM (OuroborosApplication ptcls) =
  Mux.MuxApplication
    [ Mux.MuxMiniProtocol {
        Mux.miniProtocolNum    = miniProtocolNum ptcl,
        Mux.miniProtocolLimits = miniProtocolLimits ptcl,
        Mux.miniProtocolRun    = toMuxRunMiniProtocol (miniProtocolRun ptcl)
      }
    | ptcl <- ptcls connectionId controlMessageSTM ]

-- | Translate a 'RunMiniProtocol' into the 'Mux.RunMiniProtocol' used by
-- "Network.Mux.Compat".
--
-- Each 'MuxPeer' becomes a function of a mux channel.  The channel is first
-- adapted with 'fromChannel', and the peer is then run with 'runMuxPeer'.
-- The mode index is preserved constructor-for-constructor.
--
toMuxRunMiniProtocol :: forall mode m a b.
                        (MonadCatch m, MonadAsync m)
                     => RunMiniProtocol mode LBS.ByteString m a b
                     -> Mux.RunMiniProtocol mode m a b
toMuxRunMiniProtocol (InitiatorProtocolOnly i) =
  Mux.InitiatorProtocolOnly (runMuxPeer i . fromChannel)
toMuxRunMiniProtocol (ResponderProtocolOnly r) =
  Mux.ResponderProtocolOnly (runMuxPeer r . fromChannel)
toMuxRunMiniProtocol (InitiatorAndResponderProtocol i r) =
  Mux.InitiatorAndResponderProtocol (runMuxPeer i . fromChannel)
                                    (runMuxPeer r . fromChannel)

-- | Run a @'MuxPeer'@ over a 'Channel'.
--
-- * 'MuxPeer' is run with @'runPeer'@.
-- * 'MuxPeerPipelined' is run with @'runPipelinedPeer'@.
-- * 'MuxPeerRaw' applies its action to the channel directly.
--
-- The result pairs the peer's result with the @'Maybe' bytes@ returned by the
-- driver or the raw action.
--
runMuxPeer
  :: ( MonadCatch m
     , MonadAsync m
     )
  => MuxPeer bytes m a
  -> Channel m bytes
  -> m (a, Maybe bytes)
runMuxPeer (MuxPeer tracer codec peer) channel =
    runPeer tracer codec channel peer

runMuxPeer (MuxPeerPipelined tracer codec peer) channel =
    runPipelinedPeer tracer codec channel peer

runMuxPeer (MuxPeerRaw action) channel =
    action channel