{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Cardano.Shell.NodeIPC.General
( NodeChannel
, NodeChannelError(..)
, NodeChannelFinished(..)
, setupNodeChannel
, runNodeChannel
) where
import Cardano.Prelude

import Control.Concurrent.Async (concurrently_, race)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar)
import Control.Exception (IOException, catch, tryJust)
import Control.Monad (forever)
import Data.Aeson (FromJSON (..), ToJSON (..), eitherDecode, encode)
import Data.Bifunctor (first)
import Data.Binary.Get (getWord32le, getWord64le, runGet)
import Data.Binary.Put (putLazyByteString, putWord32le, putWord64le,
                        runPut)
import Data.Functor (($>))
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Word (Word32, Word64)
import GHC.IO.Handle.FD (fdToHandle)
import System.Environment (lookupEnv)
import System.Info (os)
import System.IO (Handle, hFlush, hGetLine, hSetNewlineMode,
                  noNewlineTranslation)
import System.IO.Error (IOError, eofErrorType, ioError, mkIOError,
                        userError)
import Text.Read (readEither)

import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.Text as T
-- | Reasons why 'setupNodeChannel' could not produce a 'NodeChannel'.
data NodeChannelError
  = NodeChannelDisabled
    -- ^ @NODE_CHANNEL_FD@ is unset or empty, so there is no channel to open.
  | NodeChannelBadFD Text
    -- ^ @NODE_CHANNEL_FD@ is set, but its value could not be parsed as a
    -- file descriptor or could not be turned into a handle. The 'Text' is a
    -- human-readable description of what went wrong.
  deriving (Eq, Show)
-- | Why the channel listener stopped. 'ipcListener' returns this (and
-- 'runNodeChannel' reports it as 'Left') once the IOException that ended its
-- read/write loops has been caught; the wrapped value is that exception.
newtype NodeChannelFinished = NodeChannelFinished IOError
-- | Handle to the parent process's IPC channel. The constructor is not
-- exported from this module; values are created by 'setupNodeChannel' from
-- the file descriptor in @NODE_CHANNEL_FD@.
newtype NodeChannel = NodeChannel Handle
-- | Open the IPC channel whose file descriptor the parent process passed in
-- the @NODE_CHANNEL_FD@ environment variable.
--
-- * Variable unset or empty: @'Left' 'NodeChannelDisabled'@.
-- * Value not readable as a file descriptor: @'Left' ('NodeChannelBadFD' msg)@
--   where @msg@ includes the parse error.
-- * 'fdToHandle' throws an 'IOException': @'Left' ('NodeChannelBadFD' msg)@
--   where @msg@ is the shown exception.
setupNodeChannel :: IO (Either NodeChannelError NodeChannel)
setupNodeChannel = do
    raw <- lookupEnv "NODE_CHANNEL_FD"
    case fromMaybe "" raw of
        "" -> pure (Left NodeChannelDisabled)
        value -> either parseFailed openFd (readEither value)
  where
    -- The value was present but is not a readable file descriptor number.
    parseFailed :: String -> IO (Either NodeChannelError NodeChannel)
    parseFailed reason = pure . Left . NodeChannelBadFD $
        "unable to parse NODE_CHANNEL_FD: " <> T.pack reason

    -- Wrap the descriptor in a Handle; only IOExceptions are turned into
    -- 'NodeChannelBadFD', anything else propagates.
    openFd fd = tryJust handleBadFd (NodeChannel <$> fdToHandle fd)

    handleBadFd :: IOException -> Maybe NodeChannelError
    handleBadFd err = Just (NodeChannelBadFD (T.pack (show err)))
-- | Run a bidirectional JSON message exchange with the parent process over a
-- 'NodeChannel'.
--
-- Each incoming message is decoded and passed to @onMsg@; a decode failure
-- arrives as 'Left' with the error text. A 'Just' reply from @onMsg@ is sent
-- back to the parent. @handleMsg@ runs concurrently and receives a function
-- that queues an outgoing message.
--
-- The two sides are run with 'race'. The result is 'Left' when the listener
-- stops, which it only does after an 'IOException' ends its loops. The result
-- is 'Right' with @handleMsg@'s value when that action finishes first.
runNodeChannel
  :: (FromJSON msgin, ToJSON msgout)
  => (Either Text msgin -> IO (Maybe msgout))
     -- ^ Handler for each incoming message (or its decode error).
  -> ((msgout -> IO ()) -> IO a)
     -- ^ Action run alongside the listener, given a "send" function.
  -> NodeChannel
     -- ^ Channel obtained from 'setupNodeChannel'.
  -> IO (Either NodeChannelFinished a)
runNodeChannel onMsg handleMsg (NodeChannel h) =
    newEmptyMVar >>= \outbox ->
        -- Both the listener's replies and handleMsg's sends go through the
        -- same MVar, which the listener drains onto the handle.
        race (ipcListener h outbox onMsg) (handleMsg (putMVar outbox))
-- | Serve the channel until an 'IOException' stops it, then return that
-- exception wrapped in 'NodeChannelFinished'.
--
-- The handle's newline translation is switched off first. Two loops then run
-- concurrently:
--
-- * inbound: read a message from @h@, decode it, pass it to @onMsg@, and
--   queue any reply on @chan@;
-- * outbound: take messages from @chan@ and write them to @h@.
ipcListener
  :: forall msgin msgout. (FromJSON msgin, ToJSON msgout)
  => Handle
  -> MVar msgout
  -> (Either Text msgin -> IO (Maybe msgout))
  -> IO NodeChannelFinished
ipcListener h chan onMsg = do
    hSetNewlineMode h noNewlineTranslation
    reason <- serve `catch` \(e :: IOException) -> pure e
    pure (NodeChannelFinished reason)
  where
    -- Both loops run 'forever', so getting past 'concurrently_' would mean
    -- it returned normally; 'unexpected' marks that should-not-happen case.
    serve :: IO IOException
    serve = do
        concurrently_ inbound outbound
        pure unexpected

    inbound :: IO ()
    inbound = forever $ do
        decoded <- recvMsg
        reply <- onMsg decoded
        maybeSend reply

    outbound :: IO ()
    outbound = forever $ takeMVar chan >>= sendMsg

    recvMsg :: IO (Either Text msgin)
    recvMsg = do
        raw <- readMessage h
        pure (first T.pack (eitherDecode raw))

    sendMsg :: msgout -> IO ()
    sendMsg msg = sendMessage h (encode msg)

    maybeSend :: Maybe msgout -> IO ()
    maybeSend = \case
        Nothing -> pure ()
        Just reply -> putMVar chan reply

    unexpected :: IOException
    unexpected = userError "ipcListener: unreachable code"
-- | Read one raw (still JSON-encoded) message from the channel, using the
-- framing selected by 'isWindows'.
readMessage :: Handle -> IO BL.ByteString
readMessage h
    | isWindows = readMessageWindows h
    | otherwise = readMessagePosix h
-- | Read one message in the Windows framing. The frame is laid out as:
--
-- * two little-endian 'Word32' header fields (read and ignored here);
-- * a little-endian 'Word64' payload length;
-- * exactly that many payload bytes.
--
-- Every read must return the full number of bytes requested. A short read,
-- which 'BL.hGet' only produces at end of file, is raised as an EOF
-- 'IOError'. The listener's 'IOException' handler therefore ends the channel
-- cleanly, as it does for 'hGetLine' hitting EOF on the other platforms.
-- Lengths are checked before decoding because 'runGet' calls 'error' on short
-- input, and that 'ErrorCall' is not an 'IOException'. The check also
-- prevents a truncated payload from being returned as if it were complete.
readMessageWindows :: Handle -> IO BL.ByteString
readMessageWindows h = do
    _int1 <- readInt32 h
    _int2 <- readInt32 h
    size  <- readInt64 h
    hGetExactly h (fromIntegral size)
  where
    readInt64 :: Handle -> IO Word64
    readInt64 hnd = runGet getWord64le <$> hGetExactly hnd 8

    readInt32 :: Handle -> IO Word32
    readInt32 hnd = runGet getWord32le <$> hGetExactly hnd 4

    -- Read exactly @n@ bytes. 'BL.hGet' blocks until it has @n@ bytes or hits
    -- EOF, so a shorter result means the stream ended mid-message.
    hGetExactly :: Handle -> Int -> IO BL.ByteString
    hGetExactly hnd n = do
        bs <- BL.hGet hnd n
        if BL.length bs == fromIntegral n
            then pure bs
            else ioError
                (mkIOError eofErrorType "readMessageWindows" (Just hnd) Nothing)
-- | Read one newline-terminated message; 'hGetLine' strips the newline.
--
-- Each 'Char' is packed into a single byte by 'L8.pack', which is lossless
-- only when every character is below U+0100.
-- NOTE(review): presumably true because the handle comes from 'fdToHandle'
-- and is in binary mode — confirm if other handles are ever passed.
readMessagePosix :: Handle -> IO BL.ByteString
readMessagePosix h = do
    line <- hGetLine h
    pure (L8.pack line)
-- | Write one already-encoded message using the framing selected by
-- 'isWindows', then flush the handle.
sendMessage :: Handle -> BL.ByteString -> IO ()
sendMessage h msg = do
    writeFramed h msg
    hFlush h
  where
    writeFramed :: Handle -> BL.ByteString -> IO ()
    writeFramed
        | isWindows = sendMessageWindows
        | otherwise = sendMessagePosix
-- | Write a message in the Windows framing with the two header words fixed to
-- @1@ and @0@.
sendMessageWindows :: Handle -> BL.ByteString -> IO ()
sendMessageWindows h msg = sendMessageWindows' 1 0 h msg
-- | Write @blob@ in the Windows framing, as a single 'L8.hPut' of:
--
-- * @int1@ and @int2@ as little-endian 'Word32's;
-- * the payload length as a little-endian 'Word64';
-- * the payload itself.
--
-- The payload is @blob@ with a trailing newline appended, and that newline is
-- counted in the length.
sendMessageWindows' :: Word32 -> Word32 -> Handle -> BL.ByteString -> IO ()
sendMessageWindows' int1 int2 h blob = L8.hPut h (runPut frame)
  where
    payload :: BL.ByteString
    payload = blob <> "\n"

    frame = do
        putWord32le int1
        putWord32le int2
        putWord64le (fromIntegral (BL.length payload))
        putLazyByteString payload
-- | Write the message bytes followed by a newline byte.
sendMessagePosix :: Handle -> BL.ByteString -> IO ()
sendMessagePosix h msg = L8.hPutStrLn h msg
-- | Whether the program is running on Windows. This selects the
-- length-prefixed framing in 'readMessage' and 'sendMessage' instead of the
-- newline-delimited one.
--
-- GHC's 'System.Info.os' reports @"mingw32"@ on Windows (not @"windows"@), so
-- that is the value to compare against. Comparing with @"windows"@ made this
-- always 'False'.
isWindows :: Bool
isWindows = os == "mingw32"