{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Conduit where

import Control.Exception (assert, throwIO)
import qualified Data.ByteString as S
import qualified Data.IORef as I
import Data.Word8 (_0, _9, _A, _F, _a, _cr, _f, _lf)

import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

----------------------------------------------------------------

-- | A length-limited view of a @Source@: the @Source@ itself paired with a
-- mutable count of the bytes that are still to be read from it.
data ISource
    = ISource
        -- The underlying source that chunks are read from.
        !Source
        -- Number of bytes that still have to be passed downstream.
        !(I.IORef Int)

-- | Build an 'ISource' over @src@ whose remaining-byte counter starts at
-- @cnt@. The counter lives in a fresh 'I.IORef' and is decremented by
-- 'readISource' as bytes are handed out.
mkISource :: Source -> Int -> IO ISource
mkISource src cnt = do
    ref <- I.newIORef cnt
    -- '$!' forces the constructor (and thereby its strict fields) before
    -- returning.
    return $! ISource src ref

-- | Read the next piece of a length-limited body from an 'ISource'.
--
-- At most the number of bytes still recorded in the 'ISource' counter is
-- returned. Any bytes of the chunk beyond that limit are pushed back onto the
-- underlying @Source@ with 'leftoverSource', so all leftovers are retained
-- within the @Source@. Once the counter has reached zero, every call returns
-- 'S.empty' without reading from the @Source@.
--
-- Throws 'ConnectionClosedByPeer' if the @Source@ yields an empty chunk while
-- bytes are still expected.
readISource :: ISource -> IO ByteString
readISource (ISource src ref) = do
    count <- I.readIORef ref
    if count == 0
        then return S.empty
        else do
            bs <- readSource src

            -- If no chunk available, then there aren't enough bytes in the
            -- stream. Throw a ConnectionClosedByPeer
            when (S.null bs) $ throwIO ConnectionClosedByPeer

            let -- How many of the bytes in this chunk to send downstream
                toSend = min count (S.length bs)
                -- How many bytes will still remain to be sent downstream
                count' = count - toSend

            I.writeIORef ref count'

            if count' > 0
                then
                    -- The expected count is greater than the size of the
                    -- chunk we just read. Send the entire chunk downstream;
                    -- the next call to this function fetches the next chunk.
                    return bs
                else do
                    -- Some of the bytes in this chunk should not be sent
                    -- downstream. Split up the chunk into the sent and
                    -- not-sent parts, add the not-sent parts onto the new
                    -- source, and send the rest of the chunk downstream.
                    let (x, y) = S.splitAt toSend bs
                    leftoverSource src y
                    assert (count' == 0) $ return x

----------------------------------------------------------------

-- | A chunked-transfer-encoded view of a @Source@: the @Source@ itself paired
-- with the mutable 'ChunkState' that 'readCSource' uses to remember where in
-- the chunked stream it stopped.
data CSource
    = CSource
        -- The underlying source that raw bytes are read from.
        !Source
        -- Decoder state carried over between reads.
        !(I.IORef ChunkState)

-- | Decoder state for a chunked body, as stored in a 'CSource' and
-- interpreted by 'readCSource'.
data ChunkState
    = -- | The next read has to start by parsing a chunk-size line.
      NeedLen
    | -- | All data of the current chunk has been returned; the line break
      -- that follows it must be dropped before the next chunk-size line is
      -- parsed.
      NeedLenNewline
    | -- | This many bytes of the current chunk are still to be returned.
      HaveLen Word
    | -- | The body has ended; further reads return an empty 'ByteString'.
      DoneChunking
    deriving (Show)

-- | Build a 'CSource' over @src@. The decoder state starts as 'NeedLen', so
-- the first 'readCSource' begins by parsing a chunk-size line.
mkCSource :: Source -> IO CSource
mkCSource src = do
    ref <- I.newIORef NeedLen
    -- '$!' forces the constructor (and thereby its strict fields) before
    -- returning.
    return $! CSource src ref

-- | Read the next piece of a chunked-transfer-encoded body from a 'CSource'.
--
-- The 'ChunkState' stored in the 'CSource' records where in the chunked
-- stream the previous call stopped; it is updated before this function
-- returns. Chunk-size lines and the line breaks following chunk data are
-- consumed here and are never part of the result.
--
-- Returns 'S.empty' when the body has ended (a zero-size chunk was seen, or
-- the underlying @Source@ ran dry) and on every call after that.
readCSource :: CSource -> IO ByteString
readCSource (CSource src ref) = do
    mlen <- I.readIORef ref
    go mlen
  where
    -- Hand out up to @len@ bytes of chunk data from @bs@, the bytes that
    -- follow the chunk-size line (or continue a partially delivered chunk).
    withLen :: Word -> ByteString -> IO ByteString
    withLen 0 bs = do
        -- Zero-size chunk: the body is over. Put back what was read past
        -- the size line, then drop the line break that follows.
        leftoverSource src bs
        dropCRLF
        yield' S.empty DoneChunking
    withLen len bs
        | S.null bs = do
            -- The source ran dry in the middle of a chunk.
            -- FIXME should this throw an exception if len > 0?
            I.writeIORef ref DoneChunking
            return S.empty
        | otherwise =
            case S.length bs `compare` fromIntegral len of
                -- Exactly the rest of the chunk.
                EQ -> yield' bs NeedLenNewline
                -- Less than the rest of the chunk: remember what is missing.
                LT -> yield' bs $ HaveLen $ len - fromIntegral (S.length bs)
                -- More than the rest of the chunk: push the surplus back.
                GT -> do
                    let (x, y) = S.splitAt (fromIntegral len) bs
                    leftoverSource src y
                    yield' x NeedLenNewline

    -- Store the next state and return the given result.
    yield' :: b -> ChunkState -> IO b
    yield' bs mlen = do
        I.writeIORef ref mlen
        return bs

    -- Drop a line break from the front of the source. A leading CR is
    -- removed and the LF is left to 'dropLF'; a bare leading LF is removed
    -- on its own; anything else is pushed back untouched. An empty read is
    -- ignored.
    dropCRLF :: IO ()
    dropCRLF = do
        bs <- readSource src
        case S.uncons bs of
            Nothing -> return ()
            Just (w8, bs')
                | w8 == _cr -> dropLF bs'
                | w8 == _lf -> leftoverSource src bs'
                | otherwise -> leftoverSource src bs

    -- Called after a CR has been removed: drop a leading LF if there is
    -- one and push the remainder back. If the bytes at hand are exhausted,
    -- more are fetched with readSource' (defined elsewhere); an empty
    -- result from that read ends the attempt.
    dropLF :: ByteString -> IO ()
    dropLF bs =
        case S.uncons bs of
            Nothing -> do
                bs2 <- readSource' src
                unless (S.null bs2) $ dropLF bs2
            Just (w8, bs') ->
                leftoverSource src $
                    if w8 == _lf then bs' else bs

    -- Dispatch on the state left behind by the previous call.
    go :: ChunkState -> IO ByteString
    go NeedLen = getLen
    go NeedLenNewline = dropCRLF >> getLen
    go (HaveLen 0) = do
        -- Drop the final CRLF
        dropCRLF
        I.writeIORef ref DoneChunking
        return S.empty
    go (HaveLen len) = do
        bs <- readSource src
        withLen len bs
    go DoneChunking = return S.empty

    -- Get the length from the source, and then pass off control to withLen
    getLen :: IO ByteString
    getLen = do
        bs <- readSource src
        if S.null bs
            then do
                -- Nothing to read where a chunk-size line was expected.
                -- 'I.writeIORef' does not force its argument, so the
                -- 'assert' (when assertions are enabled) fires only when
                -- the stored state is next inspected; otherwise the state
                -- is simply @HaveLen 0@.
                I.writeIORef ref $ assert False $ HaveLen 0
                return S.empty
            else do
                -- Split at the LF ending the size line. If this read has
                -- no LF, fetch one more piece and split the concatenation.
                (x, y) <-
                    case S.break (== _lf) bs of
                        (x, y)
                            | S.null y -> do
                                bs2 <- readSource' src
                                return $
                                    if S.null bs2
                                        then (x, y)
                                        else S.break (== _lf) $ bs `S.append` bs2
                            | otherwise -> return (x, y)
                -- Only the leading hex digits of the line are used; the
                -- rest of the line is ignored.
                -- NOTE: accumulated in a 'Word' without an overflow check.
                let w =
                        S.foldl' (\i c -> i * 16 + fromIntegral (hexToWord c)) 0 $
                            S.takeWhile isHexDigit x

                -- Skip the LF itself; read again if nothing follows it.
                let y' = S.drop 1 y
                y'' <-
                    if S.null y'
                        then readSource src
                        else return y'
                withLen w y''

    -- Numeric value of a hex digit byte. Only meaningful for bytes accepted
    -- by 'isHexDigit': 55 maps @A@ to 10 and 87 maps @a@ to 10.
    hexToWord :: Word8 -> Word8
    hexToWord w
        | w <= _9 = w - _0
        | w <= _F = w - 55
        | otherwise = w - 87

-- | Whether a byte is an ASCII hexadecimal digit: @0@-@9@, @A@-@F@ or
-- @a@-@f@.
isHexDigit :: Word8 -> Bool
isHexDigit w =
    w >= _0 && w <= _9
        || w >= _A && w <= _F
        || w >= _a && w <= _f