{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.Codec
  ( module Network.TypedProtocol.Codec
  , DeserialiseFailure
  , mkCodecCborLazyBS
  , mkCodecCborStrictBS
  ) where

import           Control.Monad.ST
import           Control.Monad.Class.MonadST (MonadST (..))

import qualified Codec.CBOR.Decoding as CBOR (Decoder)
import qualified Codec.CBOR.Encoding as CBOR (Encoding)
import qualified Codec.CBOR.Read  as CBOR
import qualified Codec.CBOR.Write as CBOR
import qualified Data.ByteString.Builder as BS
import qualified Data.ByteString.Builder.Extra as BS
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Internal as LBS (smallChunkSize)

import           Network.TypedProtocol.Core
import           Network.TypedProtocol.Codec


-- | The failure type reported by the codecs in this module when the @cborg@
-- decoder rejects its input. This is a synonym for @cborg@'s
-- 'CBOR.DeserialiseFailure', re-exported so that users of the codecs need
-- not depend on "Codec.CBOR.Read" directly.
type DeserialiseFailure = CBOR.DeserialiseFailure

-- | Construct a 'Codec' for a CBOR based serialisation format, using strict
-- 'BS.ByteString's.
--
-- This is an adaptor between the @cborg@ library and the 'Codec' abstraction.
--
-- It takes encode and decode functions for the protocol messages that use the
-- CBOR library encoder and decoder.
--
-- Each encoded message is rendered to a single strict 'BS.ByteString' with
-- 'CBOR.toStrictByteString'. Decoding is incremental; see
-- 'convertCborDecoderBS' for how input chunks and leftover bytes are handled.
--
-- Note that this is /less/ efficient than the 'mkCodecCborLazyBS' variant
-- because it has to copy and concatenate the result of the encoder (which
-- natively produces chunks).
--
mkCodecCborStrictBS
  :: forall ps m. MonadST m

  => (forall (pr :: PeerRole) (st :: ps) (st' :: ps).
             PeerHasAgency pr st
          -> Message ps st st' -> CBOR.Encoding)

  -> (forall (pr :: PeerRole) (st :: ps) s.
             PeerHasAgency pr st
          -> CBOR.Decoder s (SomeMessage st))

  -> Codec ps DeserialiseFailure m BS.ByteString
mkCodecCborStrictBS cborMsgEncode cborMsgDecode =
    Codec {
      encode = \stok msg -> convertCborEncoder (cborMsgEncode stok) msg,
      decode = \stok     -> convertCborDecoder (cborMsgDecode stok)
    }
  where
    -- Run a CBOR encoder and flatten its output into one strict bytestring.
    convertCborEncoder :: (a -> CBOR.Encoding) -> a -> BS.ByteString
    convertCborEncoder cborEncode =
        CBOR.toStrictByteString
      . cborEncode

    -- 'withLiftST' hands us a lifting function for some state thread @s@
    -- that we do not choose, so the decoder must be polymorphic in @s@.
    convertCborDecoder
      :: (forall s. CBOR.Decoder s a)
      -> m (DecodeStep BS.ByteString DeserialiseFailure m a)
    convertCborDecoder cborDecode =
        withLiftST (convertCborDecoderBS cborDecode)

-- | Drive a @cborg@ incremental decoder over strict 'BS.ByteString' input and
-- present it as a 'DecodeStep'.
--
-- The second argument lifts the decoder's 'ST' actions into @m@.
--
-- * When the decoder finishes, the unconsumed bytes of the last chunk are
--   returned as @'Just' trailing@, or 'Nothing' if there are none.
-- * A decoder failure is reported as 'DecodeFail'.
-- * When the decoder needs more input, a 'DecodePartial' continuation is
--   returned. Its argument is passed unchanged to the @cborg@ continuation,
--   for which 'Nothing' signals end of input.
convertCborDecoderBS
  :: forall s m a. Functor m
  => (CBOR.Decoder s a)
  -> (forall b. ST s b -> m b)
  -> m (DecodeStep BS.ByteString DeserialiseFailure m a)
convertCborDecoderBS cborDecode liftST =
    go <$> liftST (CBOR.deserialiseIncremental cborDecode)
  where
    -- Translate one @cborg@ decoder state into the typed-protocols step type.
    go :: CBOR.IDecode s a
       -> DecodeStep BS.ByteString DeserialiseFailure m a
    go (CBOR.Done trailing _ x)
      | BS.null trailing       = DecodeDone x Nothing
      | otherwise              = DecodeDone x (Just trailing)
    go (CBOR.Fail _ _ failure) = DecodeFail failure
    go (CBOR.Partial k)        = DecodePartial (fmap go . liftST . k)


-- | Construct a 'Codec' for a CBOR based serialisation format, using lazy
-- 'LBS.ByteString's.
--
-- This is an adaptor between the @cborg@ library and the 'Codec' abstraction.
--
-- It takes encode and decode functions for the protocol messages that use the
-- CBOR library encoder and decoder.
--
-- Encoding goes through 'CBOR.toBuilder' and 'toLazyByteString', so the
-- encoder's chunked output is kept as the chunks of a lazy bytestring rather
-- than concatenated. Decoding is incremental; see 'convertCborDecoderLBS' for
-- how lazy input is fed to the strict-chunk @cborg@ decoder.
--
mkCodecCborLazyBS
  :: forall ps m. MonadST m

  => (forall (pr :: PeerRole) (st :: ps) (st' :: ps).
             PeerHasAgency pr st
          -> Message ps st st' -> CBOR.Encoding)

  -> (forall (pr :: PeerRole) (st :: ps) s.
             PeerHasAgency pr st
          -> CBOR.Decoder s (SomeMessage st))

  -> Codec ps DeserialiseFailure m LBS.ByteString
mkCodecCborLazyBS cborMsgEncode cborMsgDecode =
    Codec {
      encode = \stok msg -> convertCborEncoder (cborMsgEncode stok) msg,
      decode = \stok     -> convertCborDecoder (cborMsgDecode stok)
    }
  where
    -- Run a CBOR encoder and render it via a builder into a lazy bytestring.
    convertCborEncoder :: (a -> CBOR.Encoding) -> a -> LBS.ByteString
    convertCborEncoder cborEncode =
        toLazyByteString
      . CBOR.toBuilder
      . cborEncode

    -- 'withLiftST' hands us a lifting function for some state thread @s@
    -- that we do not choose, so the decoder must be polymorphic in @s@.
    convertCborDecoder
      :: (forall s. CBOR.Decoder s a)
      -> m (DecodeStep LBS.ByteString DeserialiseFailure m a)
    convertCborDecoder cborDecode =
        withLiftST (convertCborDecoderLBS cborDecode)

-- | Drive a @cborg@ incremental decoder over lazy 'LBS.ByteString' input and
-- present it as a 'DecodeStep'.
--
-- The second argument lifts the decoder's 'ST' actions into @m@.
--
-- @cborg@ consumes strict chunks, so each lazy input is split with
-- 'LBS.toChunks' and the chunks are fed to the decoder one at a time. The
-- caller is asked for more input only once the buffered chunks run out.
--
-- On success, the unconsumed bytes of the current chunk, followed by any
-- chunks not yet fed to the decoder, are returned as the trailing input.
-- 'Nothing' is returned if there are none.
convertCborDecoderLBS
  :: forall s m a. Monad m
  => (CBOR.Decoder s a)
  -> (forall b. ST s b -> m b)
  -> m (DecodeStep LBS.ByteString DeserialiseFailure m a)
convertCborDecoderLBS cborDecode liftST =
    go [] =<< liftST (CBOR.deserialiseIncremental cborDecode)
  where
    -- Have to mediate between a CBOR decoder that consumes strict bytestrings
    -- and our choice here that consumes lazy bytestrings. The list argument
    -- holds strict chunks already received but not yet fed to the decoder.
    go :: [BS.ByteString] -> CBOR.IDecode s a
       -> m (DecodeStep LBS.ByteString DeserialiseFailure m a)
    go cs (CBOR.Done trailing _ x) =
        -- Leftover of the current chunk comes first, then unfed chunks.
        return (DecodeDone x (nonEmptyTrailing (LBS.fromChunks (trailing : cs))))
    go _  (CBOR.Fail _ _ e) = return (DecodeFail e)

    -- We keep a bunch of chunks and supply the CBOR decoder with them
    -- until we run out, when we go get another bunch.
    go (c:cs) (CBOR.Partial k) = go cs =<< liftST (k (Just c))
    go []     (CBOR.Partial k) =
        return $ DecodePartial $ \mbs -> case mbs of
          -- End of input: tell the decoder so it can finish or fail.
          Nothing -> go [] =<< liftST (k Nothing)
          -- New input: buffer its chunks and resume the same decoder state.
          Just bs -> go (LBS.toChunks bs) (CBOR.Partial k)

    -- Report an empty remainder as 'Nothing' rather than @'Just' ""@.
    nonEmptyTrailing :: LBS.ByteString -> Maybe LBS.ByteString
    nonEmptyTrailing rest
      | LBS.null rest = Nothing
      | otherwise     = Just rest

{-# NOINLINE toLazyByteString #-}
-- | Render a 'BS.Builder' to a lazy 'LBS.ByteString' using a buffer
-- allocation strategy tuned for encoding network protocol messages, rather
-- than the default strategy of 'BS.toLazyByteString'.
toLazyByteString :: BS.Builder -> LBS.ByteString
toLazyByteString = BS.toLazyByteStringWith strategy LBS.empty
  where
    -- Buffer strategy and sizes better tuned to our network protocol situation.
    --
    -- The first buffer is 800 bytes; subsequent buffers are
    -- LBS.smallChunkSize, which is 4k - heap object overheads, so that
    -- it does fit in a 4k overall. With 'BS.untrimmedStrategy', partially
    -- filled buffers are not copied into smaller ones.
    --
    strategy :: BS.AllocationStrategy
    strategy = BS.untrimmedStrategy 800 LBS.smallChunkSize