{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.Client.ServerHello (
recvServerHello,
processServerHello13,
) where
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Extension
import Network.TLS.Handshake.Client.Common
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
-- | Receive the server's first packet (via 'recvPacket12') and treat its
-- first handshake message as the ServerHello.
--
-- The ServerHello is validated and recorded with 'processServerHello' and
-- then passed to 'processHandshake12'. Any further handshake messages that
-- arrived in the same packet are returned to the caller for later
-- processing.
--
-- Failure modes:
--
-- * a receive error is rethrown with 'throwCore';
-- * an alert packet is turned into an 'Error_Protocol' with
--   'HandshakeFailure';
-- * any other packet, including a handshake packet with no messages, is
--   reported through 'unexpected'.
recvServerHello
    :: ClientParams -> Context -> IO [Handshake]
recvServerHello cparams ctx = do
    (sh, hss) <- recvSH
    processServerHello cparams ctx sh
    processHandshake12 ctx sh
    return hss
  where
    -- Split the first handshake packet into (ServerHello, rest).
    recvSH :: IO (Handshake, [Handshake])
    recvSH = do
        epkt <- recvPacket12 ctx
        case epkt of
            Left e -> throwCore e
            Right pkt -> case pkt of
                Alert a -> throwAlert a
                Handshake (h : hs) -> return (h, hs)
                _ -> unexpected (show pkt) (Just "handshake")
    throwAlert a =
        throwCore $
            Error_Protocol
                ("expecting server hello, got alert : " ++ show a)
                HandshakeFailure
-- | Process a TLS 1.3 'ServerHello13' by repackaging it as a generic
-- 'ServerHello' and running the shared 'processServerHello' logic.
--
-- The repackaged message carries legacy_version 'TLS12' and compression
-- method 0. In TLS 1.3 the legacy_version field is fixed to TLS 1.2 and
-- the legacy compression method is 0 (RFC 8446 §4.1.3). The real
-- negotiated version is taken from the supported_versions extension
-- inside 'processServerHello'.
--
-- Any other 'Handshake13' message is rejected through 'unexpected'.
processServerHello13
    :: ClientParams -> Context -> Handshake13 -> IO ()
processServerHello13 cparams ctx (ServerHello13 serverRan serverSession cipher shExts) =
    processServerHello cparams ctx $
        ServerHello TLS12 serverRan serverSession cipher 0 shExts
processServerHello13 _ _ h = unexpected (show h) (Just "server hello")
-- | Validate a ServerHello against what the client offered, record the
-- negotiated parameters, and prepare the context for the negotiated
-- protocol version.
--
-- The same path handles a HelloRetryRequest. It is recognised by its
-- server random ('isHelloRetryRequest'), and the HRR flag and any cookie
-- are stored in the TLS state.
--
-- Each check throws an 'Error_Protocol' via 'throwCore' on failure:
--
-- * legacy version other than TLS 1.2: 'IllegalParameter';
-- * cipher or compression not offered by the client: 'IllegalParameter';
-- * extensions the client never sent: 'UnsupportedExtension';
-- * a TLS 1.3 session-id echo mismatch: 'IllegalParameter';
-- * an unsupported negotiated version: 'ProtocolVersion';
-- * a detected downgrade: 'IllegalParameter'.
--
-- A message that is not a 'ServerHello' is reported through 'unexpected'.
processServerHello
    :: ClientParams -> Context -> Handshake -> IO ()
processServerHello
    cparams
    ctx
    (ServerHello rver serverRan serverSession (CipherId cid) compression shExts) = do
        -- legacy_version must be TLS 1.2. A TLS 1.3 server signals its real
        -- version through supported_versions (RFC 8446 §4.1.3), which
        -- processServerExtension applies below.
        when (rver /= TLS12) $
            throwCore $
                Error_Protocol (show rver ++ " is not supported") IllegalParameter

        st13 <- getTLS13State ctx
        let clientSession = tls13stSession st13
            chExts = tls13stSentExtensions st13
            supported = ctxSupported ctx

        -- The server must pick a cipher and compression the client offered.
        cipherAlg <- case findCipher cid (supportedCiphers supported) of
            Nothing ->
                throwCore $ Error_Protocol "server choose unknown cipher" IllegalParameter
            Just alg -> return alg
        compressAlg <-
            case find ((== compression) . compressionID) (supportedCompressions supported) of
                Nothing ->
                    throwCore $
                        Error_Protocol "server choose unknown compression" IllegalParameter
                Just alg -> return alg
        ensureNullCompression compression

        -- Every server extension must answer one the client sent. Cookie is
        -- exempt: a HelloRetryRequest may carry one even though the initial
        -- ClientHello cannot (RFC 8446 §4.2.2).
        let isSpurious (ExtensionRaw i _)
                | i == EID_Cookie = False
                | otherwise = i `notElem` chExts
        when (any isSpurious shExts) $
            throwCore $
                Error_Protocol "spurious extensions received" UnsupportedExtension

        let isHRR = isHelloRetryRequest serverRan
        usingState_ ctx $ do
            -- The HRR flag must be set before the extensions are processed:
            -- the key_share decoder consults it.
            setTLS13HRR isHRR
            when isHRR $
                setTLS13Cookie $
                    lookupAndDecode
                        EID_Cookie
                        MsgTServerHello
                        shExts
                        Nothing
                        (\cookie@(Cookie _) -> Just cookie)
            -- Start from the legacy version; a supported_versions extension
            -- processed next may override it.
            setVersion rver
            mapM_ processServerExtension shExts

        setALPN ctx MsgTServerHello shExts

        ver <- usingState_ ctx getVersion

        when (ver == TLS12) $
            usingHState ctx $
                setServerHelloParameters rver serverRan cipherAlg compressAlg

        let supportedVers = supportedVersions $ clientSupported cparams

        -- TLS 1.3 (middlebox compatibility mode): legacy_session_id_echo
        -- must equal the legacy_session_id the client sent
        -- (RFC 8446 §4.1.3).
        when (ver == TLS13 && clientSession /= serverSession) $
            throwCore $
                Error_Protocol
                    "session is not matched in compatibility mode"
                    IllegalParameter

        when (ver `notElem` supportedVers) $
            throwCore $
                Error_Protocol
                    ("server version " ++ show ver ++ " is not supported")
                    ProtocolVersion

        -- Presumably the RFC 8446 §4.1.3 downgrade sentinel check on the
        -- server random; see isDowngraded.
        when (isDowngraded ver supportedVers serverRan) $
            throwCore $
                Error_Protocol "version downgrade detected" IllegalParameter

        if ver == TLS13
            then do
                usingState_ ctx $ setSession serverSession
                processRecordSizeLimit cparams ctx shExts True
                enableMyRecordLimit ctx
                enablePeerRecordLimit ctx
                updateContext13 ctx cipherAlg
            else do
                -- TLS 1.2 resumption: the client offered the first entry of
                -- clientSessions, and the server resumes it by echoing the
                -- same session ID.
                let resumingSession = case clientSessions cparams of
                        (_, sessionData) : _
                            | serverSession == clientSession -> Just sessionData
                        _ -> Nothing
                usingState_ ctx $ do
                    setSession serverSession
                    setTLS12SessionResuming $ isJust resumingSession
                processRecordSizeLimit cparams ctx shExts False
                updateContext12 ctx shExts resumingSession
processServerHello _ _ p = unexpected (show p) (Just "server hello")
-- | Apply a single ServerHello extension to the TLS state.
--
-- Handled extensions:
--
-- * secure_renegotiation: the payload must equal the encoding of the
--   stored client and server verify_data (RFC 5746). On mismatch it fails
--   with 'HandshakeFailure' via 'throwError'.
-- * supported_versions: a 'SupportedVersionsServerHello' overrides the
--   version previously set from the legacy field. An undecodable payload
--   is ignored.
-- * key_share: decoded as a HelloRetryRequest share when the HRR flag is
--   set, otherwise as a ServerHello share.
-- * pre_shared_key: decoded and stored.
-- * session_ticket: stores an empty ticket via 'setTLS12SessionTicket'.
--   NOTE(review): presumably marks that a new ticket will follow; confirm.
--
-- All other extensions are ignored.
processServerExtension :: ExtensionRaw -> TLSSt ()
processServerExtension (ExtensionRaw extID content)
    | extID == EID_SecureRenegotiation = do
        VerifyData cvd <- getVerifyData ClientRole
        VerifyData svd <- getVerifyData ServerRole
        let expected = extensionEncode $ SecureRenegotiation cvd svd
        unless (expected == content) $
            throwError $
                Error_Protocol "server secure renegotiation data not matching" HandshakeFailure
    | extID == EID_SupportedVersions =
        case extensionDecode MsgTServerHello content of
            Just (SupportedVersionsServerHello ver) -> setVersion ver
            _ -> return ()
    | extID == EID_KeyShare = do
        hrr <- getTLS13HRR
        let msgt = if hrr then MsgTHelloRetryRequest else MsgTServerHello
        setTLS13KeyShare $ extensionDecode msgt content
    | extID == EID_PreSharedKey =
        setTLS13PreSharedKey $ extensionDecode MsgTServerHello content
    | extID == EID_SessionTicket = setTLS12SessionTicket ""
processServerExtension _ = return ()
-- | Prepare the handshake state for a TLS 1.3 handshake using the cipher
-- chosen by the server.
--
-- A TLS 1.3 handshake on a connection that is already 'Established' and
-- has not reached EOF is treated as renegotiation. TLS 1.3 forbids that,
-- so it is rejected with 'ProtocolVersion'. Otherwise the cipher is
-- installed with 'setHelloParameters13'; any error it reports is raised
-- through 'failOnEitherError'.
updateContext13 :: Context -> Cipher -> IO ()
updateContext13 ctx cipherAlg = do
    established <- ctxEstablished ctx
    eof <- ctxEOF ctx
    when (established == Established && not eof) $
        throwCore $
            Error_Protocol
                "renegotiation to TLS 1.3 or later is not allowed"
                ProtocolVersion
    failOnEitherError $ usingHState ctx $ setHelloParameters13 cipherAlg
-- | Finish TLS 1.2 ServerHello processing.
--
-- The server's Extended Main Secret extension is handled via
-- 'processExtendedMainSecret'. When a session is being resumed, its stored
-- secret is installed as the main secret and logged with 'logKey'.
--
-- On resumption, the EMS status negotiated now must match the status
-- stored with the session (RFC 7627 §5.3). A mismatch is fatal
-- ('HandshakeFailure').
updateContext12 :: Context -> [ExtensionRaw] -> Maybe SessionData -> IO ()
updateContext12 ctx shExts resumingSession = do
    ems <- processExtendedMainSecret ctx TLS12 MsgTServerHello shExts
    case resumingSession of
        Nothing -> return ()
        Just sessionData -> do
            let emsSession = SessionEMS `elem` sessionFlags sessionData
            when (ems /= emsSession) $
                throwCore $
                    Error_Protocol
                        "server resumes a session which is not EMS consistent"
                        HandshakeFailure
            let mainSecret = sessionSecret sessionData
            usingHState ctx $ setMainSecret TLS12 ClientRole mainSecret
            logKey ctx (MainSecret mainSecret)
-- | Negotiate the record_size_limit extension (RFC 8449) from the server's
-- extensions.
--
-- This does nothing unless the client configured a limit
-- ('limitRecordSize' in the shared limits). If the server's extension is
-- present, it is decoded and passed to 'setPeerRecordSizeLimit'.
--
-- For TLS 1.3, if 'checkPeerRecordLimit' reports an acknowledgement, the
-- client's own limit is set to @mylim - 1@. In TLS 1.3 the advertised limit
-- also counts the one-byte inner content type (RFC 8449 §4).
processRecordSizeLimit
    :: ClientParams -> Context -> [ExtensionRaw] -> Bool -> IO ()
processRecordSizeLimit cparams ctx shExts tls13 =
    case limitRecordSize $ sharedLimit $ clientShared cparams of
        Nothing -> return ()
        Just mylim -> do
            -- NOTE(review): these are server extensions, yet they are decoded
            -- with MsgTClientHello, while the rest of this module uses
            -- MsgTServerHello for shExts. This is harmless only if the
            -- RecordSizeLimit decoder ignores the message type; confirm.
            lookupAndDecodeAndDo
                EID_RecordSizeLimit
                MsgTClientHello
                shExts
                (return ())
                (setPeerRecordSizeLimit ctx tls13)
            ack <- checkPeerRecordLimit ctx
            when (ack && tls13) $ setMyRecordLimit ctx $ Just (mylim - 1)