@ -30,8 +30,6 @@
@@ -30,8 +30,6 @@
*/
# include <stdio.h>
# include <stdlib.h> /* malloc() */
# include <string.h> /* strncpy() */
# include <sys/socket.h>
# include "tls.h"
# define TLS_HEADER_LEN 5
@ -42,19 +40,20 @@
@@ -42,19 +40,20 @@
# define MIN(X, Y) ((X) < (Y) ? (X) : (Y))
# endif
static int parse_extensions ( const char * , size_t , char * * ) ;
static int parse_server_name_extension ( const char * , size_t , char * * ) ;
const char tls_alert [ ] = {
0x15 , /* TLS Alert */
0x03 , 0x01 , /* TLS version */
0x00 , 0x02 , /* Payload length */
0x02 , 0x28 , /* Fatal, handshake failure */
struct TLSProtocol {
int use_alpn ;
char * * sni_hostname_list ;
char * * alpn_protocol_list ;
} ;
/* Parse a TLS packet for the Server Name Indication extension in the client
* hello handshake , returning the first servername found ( pointer to static
* array )
static int parse_extensions ( const struct TLSProtocol * , const char * , size_t ) ;
static int parse_server_name_extension ( const struct TLSProtocol * , const char * , size_t ) ;
static int parse_alpn_extension ( const struct TLSProtocol * , const char * , size_t ) ;
static int has_match ( char * * , const char * , size_t ) ;
/* Parse a TLS packet for the Server Name Indication and ALPN extension in the client
* hello handshake , returning a status code
*
* Returns :
* > = 0 - length of the hostname and updates * hostname
@ -62,35 +61,20 @@ const char tls_alert[] = {
@@ -62,35 +61,20 @@ const char tls_alert[] = {
* - 1 - Incomplete request
* - 2 - No Host header included in this request
* - 3 - Invalid hostname pointer
* - 4 - malloc failure
* < - 4 - Invalid TLS client hello
*/
int
parse_tls_header ( const char * data , size_t data_len , char * * hostname ) {
parse_tls_header ( const struct TLSProtocol * tls_data , const char * data , size_t data_len ) {
char tls_content_type ;
char tls_version_major ;
char tls_version_minor ;
size_t pos = TLS_HEADER_LEN ;
size_t len ;
if ( hostname = = NULL )
return - 3 ;
/* Check that our TCP payload is at least large enough for a TLS header */
if ( data_len < TLS_HEADER_LEN )
return - 1 ;
/* SSL 2.0 compatible Client Hello
*
* High bit of first byte ( length ) and content type is Client Hello
*
* See RFC5246 Appendix E .2
*/
if ( data [ 0 ] & 0x80 & & data [ 2 ] = = 1 ) {
if ( verbose ) fprintf ( stderr , " Received SSL 2.0 Client Hello which can not support SNI. \n " ) ;
return - 2 ;
}
tls_content_type = data [ 0 ] ;
if ( tls_content_type ! = TLS_HANDSHAKE_CONTENT_TYPE ) {
if ( verbose ) fprintf ( stderr , " Request did not begin with TLS handshake. \n " ) ;
@ -100,7 +84,7 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
@@ -100,7 +84,7 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
tls_version_major = data [ 1 ] ;
tls_version_minor = data [ 2 ] ;
if ( tls_version_major < 3 ) {
if ( verbose ) fprintf ( stderr , " Received SSL %d.%d handshake which which can not support SNI . \n " ,
if ( verbose ) fprintf ( stderr , " Received SSL %d.%d handshake which cannot be parsed . \n " ,
tls_version_major , tls_version_minor ) ;
return - 2 ;
@ -108,7 +92,7 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
@@ -108,7 +92,7 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
/* TLS record length */
len = ( ( unsigned char ) data [ 3 ] < < 8 ) +
( unsigned char ) data [ 4 ] + TLS_HEADER_LEN ;
( unsigned char ) data [ 4 ] + TLS_HEADER_LEN ;
data_len = MIN ( data_len , len ) ;
/* Check we received entire TLS record length */
@ -167,30 +151,75 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
@@ -167,30 +151,75 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) {
if ( pos + len > data_len )
return - 5 ;
return parse_extensions ( data + pos , len , hostname ) ;
return parse_extensions ( tls_data , data + pos , len ) ;
}
int
parse_extensions ( const char * data , size_t data_len , char * * hostname ) {
static int
parse_extensions ( const struct TLSProtocol * tls_data , const char * data , size_t data_len ) {
size_t pos = 0 ;
size_t len ;
int last_matched = 0 ;
if ( tls_data = = NULL )
return - 3 ;
/* Parse each 4 bytes for the extension header */
while ( pos + 4 < = data_len ) {
/* Extension Length */
len = ( ( unsigned char ) data [ pos + 2 ] < < 8 ) +
( unsigned char ) data [ pos + 3 ] ;
len = ( ( unsigned char ) data [ pos + 2 ] < < 8 ) +
( unsigned char ) data [ pos + 3 ] ;
if ( pos + 4 + len > data_len )
return - 5 ;
size_t extension_type = ( ( unsigned char ) data [ pos ] < < 8 ) +
( unsigned char ) data [ pos + 1 ] ;
/* Check if it's a server name extension */
if ( data [ pos ] = = 0x00 & & data [ pos + 1 ] = = 0x00 ) {
/* There can be only one extension of each type, so we break
our state and move p to beinnging of the extension here */
if ( pos + 4 + len > data_len )
return - 5 ;
return parse_server_name_extension ( data + pos + 4 , len , hostname ) ;
/* There can be only one extension of each type, so we break
our state and move pos to beginning of the extension here */
if ( tls_data - > use_alpn = = 2 ) {
/* we want BOTH alpn and sni to match */
if ( extension_type = = 0x00 ) { /* Server Name */
if ( parse_server_name_extension ( tls_data , data + pos + 4 , len ) ) {
/* SNI matched */
if ( last_matched ) {
/* this is only true if ALPN matched, so return true */
return last_matched ;
} else {
/* otherwise store that SNI matched */
last_matched = 1 ;
}
} else {
// both can't match
return - 2 ;
}
} else if ( extension_type = = 0x10 ) { /* ALPN */
if ( parse_alpn_extension ( tls_data , data + pos + 4 , len ) ) {
/* ALPN matched */
if ( last_matched ) {
/* this is only true if SNI matched, so return true */
return last_matched ;
} else {
/* otherwise store that ALPN matched */
last_matched = 1 ;
}
} else {
// both can't match
return - 2 ;
}
}
} else if ( extension_type = = 0x00 & & tls_data - > use_alpn = = 0 ) { /* Server Name */
return parse_server_name_extension ( tls_data , data + pos + 4 , len ) ;
} else if ( extension_type = = 0x10 & & tls_data - > use_alpn = = 1 ) { /* ALPN */
return parse_alpn_extension ( tls_data , data + pos + 4 , len ) ;
}
pos + = 4 + len ; /* Advance to the next extension header */
}
/* Check we ended where we expected to */
if ( pos ! = data_len )
return - 5 ;
@ -198,32 +227,25 @@ parse_extensions(const char *data, size_t data_len, char **hostname) {
@@ -198,32 +227,25 @@ parse_extensions(const char *data, size_t data_len, char **hostname) {
return - 2 ;
}
int
parse_server_name_extension ( const char * data , size_t data_len ,
char * * hostname ) {
static int
parse_server_name_extension ( const struct TLSProtocol * tls_data , const char * data , size_t data_len ) {
size_t pos = 2 ; /* skip server name list length */
size_t len ;
while ( pos + 3 < data_len ) {
len = ( ( unsigned char ) data [ pos + 1 ] < < 8 ) +
( unsigned char ) data [ pos + 2 ] ;
( unsigned char ) data [ pos + 2 ] ;
if ( pos + 3 + len > data_len )
return - 5 ;
switch ( data [ pos ] ) { /* name type */
case 0x00 : /* host_name */
* hostname = malloc ( len + 1 ) ;
if ( * hostname = = NULL ) {
if ( verbose ) fprintf ( stderr , " malloc() failure \n " ) ;
return - 4 ;
if ( has_match ( tls_data - > sni_hostname_list , data + pos + 3 , len ) ) {
return len ;
} else {
return - 2 ;
}
strncpy ( * hostname , data + pos + 3 , len ) ;
( * hostname ) [ len ] = ' \0 ' ;
return len ;
default :
if ( verbose ) fprintf ( stderr , " Unknown server name extension name type: %d \n " ,
data [ pos ] ) ;
@ -236,3 +258,70 @@ parse_server_name_extension(const char *data, size_t data_len,
@@ -236,3 +258,70 @@ parse_server_name_extension(const char *data, size_t data_len,
return - 2 ;
}
static int
parse_alpn_extension ( const struct TLSProtocol * tls_data , const char * data , size_t data_len ) {
size_t pos = 2 ;
size_t len ;
while ( pos + 1 < data_len ) {
len = ( unsigned char ) data [ pos ] ;
if ( pos + 1 + len > data_len )
return - 5 ;
if ( len > 0 & & has_match ( tls_data - > alpn_protocol_list , data + pos + 1 , len ) ) {
return len ;
} else if ( len > 0 ) {
if ( verbose ) fprintf ( stderr , " Unknown ALPN name: %.*s \n " , ( int ) len , data + pos + 1 ) ;
}
pos + = 1 + len ;
}
/* Check we ended where we expected to */
if ( pos ! = data_len )
return - 5 ;
return - 2 ;
}
static int
has_match ( char * * list , const char * name , size_t name_len ) {
char * * item ;
for ( item = list ; * item ; item + + ) {
if ( verbose ) fprintf ( stderr , " matching [%.*s] with [%s] \n " , ( int ) name_len , name , * item ) ;
if ( ! strncmp ( * item , name , name_len ) ) {
return 1 ;
}
}
return 0 ;
}
struct TLSProtocol *
new_tls_data ( ) {
struct TLSProtocol * tls_data = malloc ( sizeof ( struct TLSProtocol ) ) ;
if ( tls_data ! = NULL ) {
tls_data - > use_alpn = - 1 ;
}
return tls_data ;
}
struct TLSProtocol *
tls_data_set_list ( struct TLSProtocol * tls_data , int alpn , char * * list ) {
if ( alpn ) {
tls_data - > alpn_protocol_list = list ;
if ( tls_data - > use_alpn = = 0 )
tls_data - > use_alpn = 2 ;
else
tls_data - > use_alpn = 1 ;
} else {
tls_data - > sni_hostname_list = list ;
if ( tls_data - > use_alpn = = 1 )
tls_data - > use_alpn = 2 ;
else
tls_data - > use_alpn = 0 ;
}
return tls_data ;
}