{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ouroboros.Network.Protocol.Handshake
( runHandshakeClient
, runHandshakeServer
, HandshakeArguments (..)
, HandshakeException (..)
) where
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Tracer (Tracer, contramap)
import qualified Data.ByteString.Lazy as BL
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Term as CBOR
import Network.Mux.Trace
import Network.Mux.Types
import Network.TypedProtocol.Codec
import Ouroboros.Network.Channel
import Ouroboros.Network.Driver.Limits
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Client
import Ouroboros.Network.Protocol.Handshake.Server
-- | The mux mini-protocol number on which the 'Handshake' protocol is run.
handshakeProtocolNum :: MiniProtocolNum
handshakeProtocolNum = MiniProtocolNum 0
-- | The two ways a handshake run through 'tryHandshake' can fail.
--
-- The parameter @a@ is the side-specific protocol error: the client side
-- instantiates it with 'HandshakeClientProtocolError', the server side with
-- 'RefuseReason'.
data HandshakeException a
  = HandshakeProtocolLimit ProtocolLimitFailure
    -- ^ A 'ProtocolLimitFailure' exception was thrown while the handshake
    -- was running and has been caught.
  | HandshakeProtocolError a
    -- ^ The handshake ran to completion but produced an error result.
-- | Run one side of the handshake and fold both of its failure modes into a
-- single 'Either'.
--
-- * A 'ProtocolLimitFailure' thrown by the action is caught and reported as
--   'HandshakeProtocolLimit'.
-- * A 'Left' result of the action is reported as 'HandshakeProtocolError'.
-- * A 'Right' result is passed through unchanged.
--
-- Exceptions of any other type are not caught and propagate to the caller.
--
-- Only 'try' is used here, so 'MonadCatch' is the sole constraint required.
-- The previous 'MonadAsync' constraint was redundant.
tryHandshake :: forall m a r.
                MonadCatch m
             => m (Either a r)
             -> m (Either (HandshakeException a) r)
tryHandshake doHandshake = do
    -- The exception type caught by 'try' is fixed to 'ProtocolLimitFailure'
    -- by the 'HandshakeProtocolLimit' pattern below.
    mapp <- try doHandshake
    pure $ case mapp of
      Left err          -> Left (HandshakeProtocolLimit err)
      Right (Left err)  -> Left (HandshakeProtocolError err)
      Right (Right res) -> Right res
-- | Arguments shared by 'runHandshakeClient' and 'runHandshakeServer'.
data HandshakeArguments connectionId vNumber vData m application =
    HandshakeArguments
      { haHandshakeTracer
          :: Tracer m (WithMuxBearer connectionId
                        (TraceSendRecv (Handshake vNumber CBOR.Term)))
          -- ^ Tracer for messages sent and received during the handshake.
          -- Each message is tagged with the connection identifier.

      , haHandshakeCodec
          :: Codec (Handshake vNumber CBOR.Term)
                   CBOR.DeserialiseFailure
                   m
                   BL.ByteString
          -- ^ Codec for the 'Handshake' protocol messages.

      , haVersionDataCodec
          :: VersionDataCodec CBOR.Term vNumber vData
          -- ^ Codec for the version data carried in the handshake.

      , haVersions
          :: Versions vNumber vData application
          -- ^ Supported versions, each associated with an application.
          -- The application of the negotiated version is returned on success.
      }
-- | Run the client (initiator) side of the 'Handshake' protocol.
--
-- The protocol runs in 'InitiatorDir' on mini-protocol
-- 'handshakeProtocolNum' of the given bearer. Limits come from
-- 'byteLimitsHandshake' and 'timeLimitsHandshake'. Any leftover bytes
-- returned by the driver are discarded.
--
-- On success the result is the application together with the agreed
-- version number and version data. Failures are classified by
-- 'tryHandshake'.
runHandshakeClient
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
       -- ^ bearer on which the handshake channel is opened
    -> connectionId
       -- ^ identifier attached to every traced message
    -> (vData -> vData -> Accept vData)
       -- ^ version-data acceptance function, forwarded to
       -- 'handshakeClientPeer'
    -> HandshakeArguments connectionId vNumber vData m application
    -> m (Either (HandshakeException (HandshakeClientProtocolError vNumber))
                 (application, vNumber, vData))
runHandshakeClient bearer
                   connectionId
                   acceptVersion
                   HandshakeArguments { haHandshakeTracer
                                      , haHandshakeCodec
                                      , haVersionDataCodec
                                      , haVersions
                                      } =
    tryHandshake $
      fst <$> runPeerWithLimits bearerTracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                timeLimitsHandshake
                                bearerChannel
                                clientPeer
  where
    -- Tag every traced message with the connection it belongs to.
    bearerTracer  = contramap (WithMuxBearer connectionId) haHandshakeTracer
    bearerChannel =
      fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)
    clientPeer    =
      handshakeClientPeer haVersionDataCodec acceptVersion haVersions
-- | Run the server (responder) side of the 'Handshake' protocol.
--
-- The protocol runs in 'ResponderDir' on mini-protocol
-- 'handshakeProtocolNum' of the given bearer. Limits come from
-- 'byteLimitsHandshake' and 'timeLimitsHandshake'. Any leftover bytes
-- returned by the driver are discarded.
--
-- On success the result is the application together with the agreed
-- version number and version data. A refusal or a limit violation is
-- classified by 'tryHandshake'.
runHandshakeServer
    :: ( MonadAsync m
       , MonadFork m
       , MonadMonotonicTime m
       , MonadTimer m
       , MonadMask m
       , MonadThrow (STM m)
       , Ord vNumber
       )
    => MuxBearer m
       -- ^ bearer on which the handshake channel is opened
    -> connectionId
       -- ^ identifier attached to every traced message
    -> (vData -> vData -> Accept vData)
       -- ^ version-data acceptance function, forwarded to
       -- 'handshakeServerPeer'
    -> HandshakeArguments connectionId vNumber vData m application
    -> m (Either
           (HandshakeException (RefuseReason vNumber))
           (application, vNumber, vData))
runHandshakeServer bearer
                   connectionId
                   acceptVersion
                   HandshakeArguments { haHandshakeTracer
                                      , haHandshakeCodec
                                      , haVersionDataCodec
                                      , haVersions
                                      } =
    tryHandshake $
      fst <$> runPeerWithLimits bearerTracer
                                haHandshakeCodec
                                byteLimitsHandshake
                                timeLimitsHandshake
                                bearerChannel
                                serverPeer
  where
    -- Tag every traced message with the connection it belongs to.
    bearerTracer  = contramap (WithMuxBearer connectionId) haHandshakeTracer
    bearerChannel =
      fromChannel (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)
    serverPeer    =
      handshakeServerPeer haVersionDataCodec acceptVersion haVersions