{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An extension of 'Network.TypedProtocol.Channel', with additional 'Channel'
-- implementations.
--
module Network.Mux.Channel
  ( Channel (..)
  , createBufferConnectedChannels
  , createPipeConnectedChannels
#if !defined(mingw32_HOST_OS)
  , createSocketConnectedChannels
#endif
  , withFifosAsChannel
  , socketAsChannel
  , channelEffect
  , delayChannel
  , loggingChannel
  ) where

import qualified Data.ByteString               as BS
import qualified Data.ByteString.Lazy          as LBS
import qualified Data.ByteString.Lazy.Internal as LBS (smallChunkSize)
import qualified System.Process as IO (createPipe)
import qualified System.IO      as IO
                   ( Handle, withFile, IOMode(..), hFlush, hIsEOF )
import qualified Network.Socket            as Socket
import qualified Network.Socket.ByteString as Socket

import           Control.Monad.Class.MonadSTM
import           Control.Monad.Class.MonadTimer
import           Control.Monad.Class.MonadSay


-- | A bidirectional byte channel in some monad @m@: a pair of actions for
-- sending and receiving lazy 'LBS.ByteString's.
--
data Channel m = Channel {

    -- | Write bytes to the channel.
    --
    -- It may raise exceptions.
    --
    send :: LBS.ByteString -> m (),

    -- | Read some input from the channel, or @Nothing@ to indicate EOF.
    --
    -- Note that having received EOF it is still possible to send.
    -- The EOF condition is however monotonic.
    --
    -- It may raise exceptions (as appropriate for the monad and kind of
    -- channel).
    --
    recv :: m (Maybe LBS.ByteString)
  }


-- | Make a 'Channel' from a pair of IO 'IO.Handle's, one for reading and one
-- for writing.
--
-- The Handles should be open in the appropriate read or write mode, and in
-- binary mode. Writes are flushed after each write, so it is safe to use
-- a buffering mode.
--
-- For bidirectional handles it is safe to pass the same handle for both.
--
handlesAsChannel :: IO.Handle -- ^ Read handle
                 -> IO.Handle -- ^ Write handle
                 -> Channel IO
handlesAsChannel hndRead hndWrite =
    Channel{send, recv}
  where
    -- Write the whole lazy chunk, then flush so the peer sees the bytes
    -- without waiting for the handle's buffer to fill up.
    send :: LBS.ByteString -> IO ()
    send chunk = do
      LBS.hPut hndWrite chunk
      IO.hFlush hndWrite

    -- EOF is detected with 'IO.hIsEOF'; otherwise read whatever is
    -- available, up to 'LBS.smallChunkSize' bytes.
    recv :: IO (Maybe LBS.ByteString)
    recv = do
      eof <- IO.hIsEOF hndRead
      if eof
        then return Nothing
        else Just . LBS.fromStrict <$> BS.hGetSome hndRead LBS.smallChunkSize

-- | Create a pair of 'Channel's that are connected internally.
--
-- This is intended for inter-thread communication, such as between a
-- multiplexing thread and a thread running a peer.
--
-- It uses lazy 'ByteString's but it ensures that data written to the channel
-- is /fully evaluated/ first. This ensures that any work to serialise the data
-- takes place on the /writer side and not the reader side/.
--
-- These channels never report EOF: 'recv' always returns a 'Just' and blocks
-- until the other side has written a chunk.
--
createBufferConnectedChannels :: forall m. MonadSTM m
                              => m (Channel m,
                                    Channel m)
createBufferConnectedChannels = do
    bufferA <- newEmptyTMVarIO
    bufferB <- newEmptyTMVarIO

    -- Each channel reads from the buffer that the other one writes to.
    return (buffersAsChannel bufferB bufferA,
            buffersAsChannel bufferA bufferB)
  where
    -- Build a 'Channel' that takes chunks from @bufferRead@ and puts chunks
    -- into @bufferWrite@. Each buffer is a single-slot 'TMVar' holding one
    -- strict chunk at a time.
    buffersAsChannel :: TMVar m BS.ByteString
                     -> TMVar m BS.ByteString
                     -> Channel m
    buffersAsChannel bufferRead bufferWrite =
        Channel{send, recv}
      where
        send :: LBS.ByteString -> m ()
        send x = sequence_ [ atomically (putTMVar bufferWrite c)
                           | !c <- LBS.toChunks x ]
                           -- Evaluate the chunk c /before/ doing the STM
                           -- transaction to write it to the buffer.

        recv :: m (Maybe LBS.ByteString)
        recv   = Just . LBS.fromStrict <$> atomically (takeTMVar bufferRead)


-- | Create a local pipe, with both ends in this process, and expose that as
-- a pair of 'Channel's, one for each end.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createPipeConnectedChannels :: IO (Channel IO,
                                   Channel IO)
createPipeConnectedChannels = do
    -- Create two pipes (each one is unidirectional) to make both ends of
    -- a bidirectional channel
    (hndReadA, hndWriteB) <- IO.createPipe
    (hndReadB, hndWriteA) <- IO.createPipe

    return (handlesAsChannel hndReadA hndWriteA,
            handlesAsChannel hndReadB hndWriteB)

-- | Open a pair of Unix FIFOs, and expose that as a 'Channel'.
--
-- The peer process needs to open the same files but the other way around,
-- for writing and reading.
--
-- This is primarily for the purpose of demonstrations that use communication
-- between multiple local processes. It is Unix specific.
--
-- The read FIFO is opened before the write FIFO. Both handles are closed
-- (via 'IO.withFile') when the action returns or throws.
--
withFifosAsChannel :: FilePath -- ^ FIFO for reading
                   -> FilePath -- ^ FIFO for writing
                   -> (Channel IO -> IO a) -> IO a
withFifosAsChannel fifoPathRead fifoPathWrite action =
    IO.withFile fifoPathRead  IO.ReadMode  $ \hndRead  ->
    IO.withFile fifoPathWrite IO.WriteMode $ \hndWrite ->
      let channel = handlesAsChannel hndRead hndWrite
       in action channel


-- | Make a 'Channel' from a 'Socket.Socket'. The socket must be a stream
-- socket type and status connected.
--
socketAsChannel :: Socket.Socket -> Channel IO
socketAsChannel socket =
    Channel{send, recv}
  where
    send :: LBS.ByteString -> IO ()
    send chunks =
      -- Use vectored writes.
      Socket.sendMany socket (LBS.toChunks chunks)
      -- TODO: limit write sizes, or break them into multiple sends.

    recv :: IO (Maybe LBS.ByteString)
    recv = do
      -- We rely on the behaviour of stream sockets that a zero length chunk
      -- indicates EOF.
      chunk <- Socket.recv socket LBS.smallChunkSize
      if BS.null chunk
        then return Nothing
        else return (Just (LBS.fromStrict chunk))

#if !defined(mingw32_HOST_OS)
-- | Create a local socket, with both ends in this process, and expose that as
-- a pair of 'Channel's, one for each end.
--
-- This is primarily for testing purposes since it does not allow actual IPC.
--
createSocketConnectedChannels :: Socket.Family -- ^ Usually AF_UNIX; some
                                               -- platforms (e.g. Linux) do not
                                               -- support AF_INET for socket
                                               -- pairs
                              -> IO (Channel IO,
                                     Channel IO)
createSocketConnectedChannels family = do
   -- Create a socket pair to make both ends of a bidirectional channel
   (socketA, socketB) <- Socket.socketPair family Socket.Stream
                                           Socket.defaultProtocol

   return (socketAsChannel socketA,
           socketAsChannel socketB)
#endif

-- | Wrap a 'Channel' with extra effects around its operations.
--
-- The first action receives the outgoing bytes and runs before the
-- underlying 'send'. The second action receives the result of the underlying
-- 'recv' (including 'Nothing' on EOF) and runs before that result is
-- returned to the caller.
--
channelEffect :: forall m.
                 Monad m
              => (LBS.ByteString -> m ())       -- ^ Action before 'send'
              -> (Maybe LBS.ByteString -> m ()) -- ^ Action after 'recv'
              -> Channel m
              -> Channel m
channelEffect beforeSend afterRecv Channel{send, recv} =
    Channel{
      send = \x -> do
        beforeSend x
        send x

    , recv = do
        mx <- recv
        afterRecv mx
        return mx
    }

-- | Delay a channel on the receiver end.
--
-- Every 'recv' (including one that reports EOF) waits for the given
-- 'DiffTime' after the underlying 'recv' returns, before handing back its
-- result. 'send' is not delayed.
--
-- This is intended for testing, as a crude approximation of network delays.
-- More accurate models along these lines are of course possible.
--
delayChannel :: ( MonadSTM m
                , MonadTimer m
                )
             => DiffTime
             -> Channel m
             -> Channel m
delayChannel delay = channelEffect (\_ -> return ())
                                   (\_ -> threadDelay delay)

-- | Channel which logs sent and received messages.
--
-- Each sent chunk is logged with 'say' as @show ident ++ ":send:" ++ show bytes@
-- before it is passed on. Each received chunk is logged as
-- @show ident ++ ":recv:" ++ show bytes@. EOF ('Nothing') is not logged.
--
loggingChannel :: ( MonadSay m
                  , Show id
                  )
               => id
               -> Channel m
               -> Channel m
loggingChannel ident Channel{send,recv} =
  Channel {
    send = loggingSend,
    recv = loggingRecv
  }
 where
  -- Log first, then forward to the wrapped channel's 'send'.
  loggingSend a = do
    say (show ident ++ ":send:" ++ show a)
    send a

  -- Forward to the wrapped 'recv', log any data received, return it as is.
  loggingRecv = do
    msg <- recv
    case msg of
      Nothing -> return ()
      Just a  -> say (show ident ++ ":recv:" ++ show a)
    return msg