{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | API for running the 'Handshake' protocol.
--
module Ouroboros.Network.Protocol.Handshake
  ( runHandshakeClient
  , runHandshakeServer
  , HandshakeArguments (..)
  , HandshakeException (..)
  ) where

import           Control.Monad.Class.MonadAsync
import           Control.Monad.Class.MonadFork
import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadThrow
import           Control.Monad.Class.MonadTime
import           Control.Monad.Class.MonadTimer

import           Control.Tracer (Tracer, contramap)
import qualified Data.ByteString.Lazy as BL
import qualified Codec.CBOR.Read     as CBOR
import qualified Codec.CBOR.Term     as CBOR

import           Network.Mux.Trace
import           Network.Mux.Types
import           Network.TypedProtocol.Codec

import           Ouroboros.Network.Channel
import           Ouroboros.Network.Driver.Limits

import           Ouroboros.Network.Protocol.Handshake.Type
import           Ouroboros.Network.Protocol.Handshake.Version
import           Ouroboros.Network.Protocol.Handshake.Codec
import           Ouroboros.Network.Protocol.Handshake.Client
import           Ouroboros.Network.Protocol.Handshake.Server


-- | The handshake protocol number.
--
handshakeProtocolNum :: MiniProtocolNum
handshakeProtocolNum = MiniProtocolNum 0

-- | The two ways a handshake run through 'tryHandshake' can fail, shared by
-- the initiator and the responder side.
--
data HandshakeException a
    = HandshakeProtocolLimit ProtocolLimitFailure
      -- ^ a 'ProtocolLimitFailure' exception was raised while running the
      -- protocol (the handshake is run under size and time limits)
    | HandshakeProtocolError a
      -- ^ the protocol itself finished with an error of type @a@


-- | Run one side (initiator or responder) of the Handshake protocol and fold
-- its possible failures into a single 'HandshakeException'.
--
-- * A 'ProtocolLimitFailure' exception thrown by the action (for instance by
--   'runPeerWithLimits' when a size or time limit is exceeded) is caught and
--   returned as 'HandshakeProtocolLimit'.
-- * A protocol-level error returned by the action as @Left@ is returned as
--   'HandshakeProtocolError'.
--
-- Any other exception is not caught and propagates to the caller.  This
-- function adds no timeout of its own: any time bound comes from the limits
-- the caller passes to 'runPeerWithLimits'.
--
tryHandshake :: forall m a r.
                ( MonadCatch m
                )
             => m (Either a r)
             -> m (Either (HandshakeException a) r)
tryHandshake doHandshake = do
    -- The exception type caught by 'try' is fixed to 'ProtocolLimitFailure'
    -- by the 'HandshakeProtocolLimit' constructor below.
    result <- try doHandshake
    pure $ case result of
      Left limitFailure        -> Left (HandshakeProtocolLimit limitFailure)
      Right (Left protocolErr) -> Left (HandshakeProtocolError protocolErr)
      Right (Right r)          -> Right r


--
-- Record arguments
--

-- | Common arguments for both 'Handshake' client & server.
--
data HandshakeArguments connectionId vNumber vData m application = HandshakeArguments {
      -- | 'Handshake' tracer.  Events are tagged with the connection
      -- identifier via 'WithMuxBearer'.
      --
      haHandshakeTracer :: Tracer m (WithMuxBearer connectionId
                                       (TraceSendRecv (Handshake vNumber CBOR.Term))),

      -- | Codec for protocol messages.
      --
      haHandshakeCodec
        :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m BL.ByteString,

      -- | A codec for protocol parameters (the version data @vData@).
      --
      haVersionDataCodec
        :: VersionDataCodec CBOR.Term vNumber vData,

      -- | Versioned application agreed upon with the 'Handshake' protocol.
      --
      haVersions :: Versions vNumber vData application
    }


-- | Run the client (initiator) side of the 'Handshake' protocol.
--
-- The protocol runs on mini-protocol 'handshakeProtocolNum' of the given
-- bearer in the 'InitiatorDir' direction, with 'byteLimitsHandshake' and
-- 'timeLimitsHandshake' enforced by 'runPeerWithLimits'.  Trace events are
-- tagged with @connectionId@.  Any bytes left over once the protocol finishes
-- are discarded.
--
-- Returns the application selected from 'haVersions' together with the
-- agreed version number and version data, or a 'HandshakeException' when a
-- protocol limit was violated or the handshake ended with a
-- 'HandshakeClientProtocolError'.
--
runHandshakeClient
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
       -- ^ bearer carrying the handshake mini-protocol
    -> connectionId
       -- ^ identifier attached to every trace event
    -> (vData -> vData -> Accept vData)
       -- ^ version-data acceptance function, passed to 'handshakeClientPeer'
    -> HandshakeArguments connectionId vNumber vData m application
    -> m (Either (HandshakeException (HandshakeClientProtocolError vNumber))
                 (application, vNumber, vData))
runHandshakeClient bearer
                   connectionId
                   acceptVersion
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haVersions
                   } =
    tryHandshake
      -- 'runPeerWithLimits' also returns unconsumed input; only the protocol
      -- result is kept.
      (fst <$>
        runPeerWithLimits
          (WithMuxBearer connectionId `contramap` haHandshakeTracer)
          haHandshakeCodec
          byteLimitsHandshake
          timeLimitsHandshake
          (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir))
          (handshakeClientPeer haVersionDataCodec acceptVersion haVersions))


-- | Run the server (responder) side of the 'Handshake' protocol.
--
-- The protocol runs on mini-protocol 'handshakeProtocolNum' of the given
-- bearer in the 'ResponderDir' direction, with 'byteLimitsHandshake' and
-- 'timeLimitsHandshake' enforced by 'runPeerWithLimits'.  Trace events are
-- tagged with @connectionId@.  Any bytes left over once the protocol finishes
-- are discarded.
--
-- Returns the application selected from 'haVersions' together with the
-- agreed version number and version data, or a 'HandshakeException' when a
-- protocol limit was violated or the handshake ended with a 'RefuseReason'.
--
runHandshakeServer
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
       -- ^ bearer carrying the handshake mini-protocol
    -> connectionId
       -- ^ identifier attached to every trace event
    -> (vData -> vData -> Accept vData)
       -- ^ version-data acceptance function, passed to 'handshakeServerPeer'
    -> HandshakeArguments connectionId vNumber vData m application
    -> m (Either
           (HandshakeException (RefuseReason vNumber))
           (application, vNumber, vData))
runHandshakeServer bearer
                   connectionId
                   acceptVersion
                   HandshakeArguments {
                     haHandshakeTracer,
                     haHandshakeCodec,
                     haVersionDataCodec,
                     haVersions
                   } =
    tryHandshake
      -- 'runPeerWithLimits' also returns unconsumed input; only the protocol
      -- result is kept.
      (fst <$>
        runPeerWithLimits
          (WithMuxBearer connectionId `contramap` haHandshakeTracer)
          haHandshakeCodec
          byteLimitsHandshake
          timeLimitsHandshake
          (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir))
          (handshakeServerPeer haVersionDataCodec acceptVersion haVersions))