{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Wai.Handler.Warp.HTTP2 (
    http2,
    http2server,
) where

import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Data.IORef (readIORef)
import qualified Data.IORef as I
import GHC.Conc.Sync (labelThread, myThreadId)
import qualified Network.HTTP2.Frame as H2
import qualified Network.HTTP2.Server as H2
import Network.Socket (SockAddr)
import Network.Socket.BufferPool
import Network.Wai
import Network.Wai.Internal (ResponseReceived (..))
import qualified System.TimeManager as T

import Network.Wai.Handler.Warp.HTTP2.File
import Network.Wai.Handler.Warp.HTTP2.PushPromise
import Network.Wai.Handler.Warp.HTTP2.Request
import Network.Wai.Handler.Warp.HTTP2.Response
import Network.Wai.Handler.Warp.Imports
import qualified Network.Wai.Handler.Warp.Settings as S
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | Serve HTTP/2 on an established connection by handing it to the
-- http2 library's server loop.
--
-- The function builds an 'H2.Config' from the connection's write buffer,
-- a receive function wrapped with timeout tickling, and a send function
-- that tickles the timeout handle after every write. It then marks the
-- connection as HTTP/2 and runs 'http2server' under 'H2.run'.
--
-- The final 'ByteString' argument is given to 'makeRecvN' together with
-- @connRecv conn@; presumably it holds bytes already read from the
-- connection before the switch to HTTP/2 (confirm against 'makeRecvN').
http2
    :: S.Settings
    -> InternalInfo
    -> Connection
    -> Transport
    -> Application
    -> SockAddr
    -> T.Handle
    -> ByteString
    -> IO ()
http2 settings ii conn transport app peersa th bs = do
    rawRecvN <- makeRecvN bs $ connRecv conn
    writeBuffer <- readIORef $ connWriteBuffer conn
    -- This thread becomes the sender in http2 library.
    -- In the case of event source, one request comes and one
    -- worker gets busy. But it is likely that the receiver does
    -- not receive any data at all while the sender is sending
    -- output data from the worker. It's not good enough to tickle
    -- the time handler in the receiver only. So, we should tickle
    -- the time handler in both the receiver and the sender.
    let recvN = wrappedRecvN th (S.settingsSlowlorisSize settings) rawRecvN
        -- Send the whole chunk, then tickle the timeout handle.
        sendBS x = connSendAll conn x >> T.tickle th
        conf =
            H2.Config
                { confWriteBuffer = bufBuffer writeBuffer
                , confBufferSize = bufSize writeBuffer
                , confSendAll = sendBS
                , confReadN = recvN
                , confPositionReadMaker = pReadMaker ii
                , confTimeoutManager = timeoutManager ii
#if MIN_VERSION_http2(4,2,0)
                , confMySockAddr = connMySockAddr conn
                , confPeerSockAddr = peersa
#endif
                }
    -- NOTE(review): when 'checkTLS' sends GOAWAY, execution still continues
    -- to 'H2.run' below; confirm that this is the intended behavior.
    checkTLS
    setConnHTTP2 conn True
    H2.run H2.defaultServerConfig conf $
        http2server "Warp HTTP/2" settings ii transport peersa app
  where
    -- Plain TCP needs no check. For any other transport, send a GOAWAY
    -- frame with 'H2.InadequateSecurity' unless 'tls12orLater' holds.
    checkTLS :: IO ()
    checkTLS = case transport of
        TCP -> return () -- direct
        tls -> unless (tls12orLater tls) $ goaway conn H2.InadequateSecurity "Weak TLS"

    -- True when the transport reports major version 3 and minor version
    -- of at least 3 (as the name indicates, this is meant to be TLS 1.2+).
    tls12orLater :: Transport -> Bool
    tls12orLater tls = tlsMajorVersion tls == 3 && tlsMinorVersion tls >= 3

-- | Convert a WAI application to the server type of the http2 library.
--
-- For each request the handler labels the current thread, converts the
-- http2 request into a WAI 'Request', runs the application, and forwards
-- the application's response (plus any push promises) to the http2
-- @response@ callback. The response and the push promises are then logged
-- through the loggers configured in the settings.
--
-- If the application throws a synchronous exception, it is reported via
-- 'S.settingsOnException' and the response produced by
-- 'S.settingsOnExceptionResponse' is sent instead. Asynchronous exceptions
-- are re-thrown.
--
-- Since 3.3.11
http2server
    :: String
    -> S.Settings
    -> InternalInfo
    -> Transport
    -> SockAddr
    -> Application
    -> H2.Server
http2server label settings ii transport addr app h2req0 aux0 response = do
    tid <- myThreadId
    labelThread tid (label ++ " http2server " ++ show addr)
    req <- toWAIRequest h2req0 aux0
    -- Holds what the responder callback sent, so it can be logged after
    -- the application returns.
    ref <- I.newIORef Nothing
    eResponseReceived <- E.try $ app req $ \rsp -> do
        (h2rsp, st, hasBody) <- fromResponse settings ii req rsp
        -- Push promises are only computed for responses that have a body.
        pps <- if hasBody then fromPushPromises ii req else return []
        I.writeIORef ref $ Just (h2rsp, pps, st)
        _ <- response h2rsp pps
        return ResponseReceived
    case eResponseReceived of
        Right ResponseReceived -> do
            -- 'ref' is only filled inside the responder callback above, so
            -- this pattern fails (via 'fail' in IO) if the application
            -- returned without invoking the responder.
            Just (h2rsp, pps, st) <- I.readIORef ref
            let msiz = fromIntegral <$> H2.responseBodySize h2rsp
            logResponse req st msiz
            mapM_ (logPushPromise req) pps
        Left e
            | isAsyncException e -> E.throwIO e
            | otherwise -> do
                S.settingsOnException settings (Just req) e
                let ersp = S.settingsOnExceptionResponse settings e
                    st = responseStatus ersp
                (h2rsp', _, _) <- fromResponse settings ii req ersp
                let msiz = fromIntegral <$> H2.responseBodySize h2rsp'
                -- The error response is sent without push promises.
                _ <- response h2rsp' []
                logResponse req st msiz
    return ()
  where
    -- Build the WAI request from the http2 request and its auxiliary data.
    toWAIRequest h2req aux = toRequest ii settings addr hdr bdylen bdy th transport
      where
        !hdr = H2.requestHeaders h2req
        !bdy = H2.getRequestBodyChunk h2req
        !bdylen = H2.requestBodySize h2req
        !th = H2.auxTimeHandle aux

    logResponse = S.settingsLogger settings

    -- Log one push promise; a response without a known body size is
    -- logged with size 0.
    logPushPromise req pp = logger req path siz
      where
        !logger = S.settingsServerPushLogger settings
        !path = H2.promiseRequestPath pp
        !siz = maybe 0 fromIntegral $ H2.responseBodySize $ H2.promiseResponse pp

-- | Wrap a receive function so that it tickles the timeout handle and
-- never propagates a synchronous exception to the caller.
--
-- The timeout handle is tickled when either
--
-- * the chunk read is non-empty and at least @slowlorisSize@ bytes long, or
-- * the requested @bufsize@ is no larger than @slowlorisSize@.
--
-- The chunk read (or the value produced by the exception handler) is
-- returned unchanged.
wrappedRecvN
    :: T.Handle -> Int -> (BufSize -> IO ByteString) -> (BufSize -> IO ByteString)
wrappedRecvN th slowlorisSize readN bufsize = do
    bs <- E.handle handler $ readN bufsize
    -- TODO: think about the slowloris protection in HTTP2: current code
    -- might open a slow-loris attack vector. Rather than timing we should
    -- consider limiting the per-client connections assuming that in HTTP2
    -- we should allow only a few connections per host (real-world
    -- deployments with large NATs may be trickier).
    --
    -- The parentheses spell out the grouping that operator precedence
    -- already gives ('&&' binds tighter than '||').
    when
        ((BS.length bs > 0 && BS.length bs >= slowlorisSize) || bufsize <= slowlorisSize) $
        T.tickle th
    return bs
  where
    -- Presumably re-throws asynchronous exceptions and otherwise yields
    -- the empty 'ByteString' -- confirm against 'throughAsync'.
    handler :: E.SomeException -> IO ByteString
    handler = throughAsync (return "")

-- | Encode a GOAWAY frame carrying the given error code and debug message
-- and send it on the connection with 'connSendAll'.
--
-- The frame is encoded with stream identifier 0 and unmodified flags, and
-- its payload is built with 0 as the first 'H2.GoAwayFrame' field.
--
-- connClose must not be called here since Run:fork calls it
goaway :: Connection -> H2.ErrorCodeId -> ByteString -> IO ()
goaway conn etype debugmsg = connSendAll conn bytestream
  where
    einfo = H2.encodeInfo id 0
    frame = H2.GoAwayFrame 0 etype debugmsg
    bytestream = H2.encodeFrame einfo frame