diff --git a/lib/mbedtls_wrapper/library/ssl_lib.c b/lib/mbedtls_wrapper/library/ssl_lib.c index ae6f02f3..187fc9f0 100644 --- a/lib/mbedtls_wrapper/library/ssl_lib.c +++ b/lib/mbedtls_wrapper/library/ssl_lib.c @@ -87,6 +87,9 @@ int SSL_want_nothing(const SSL *ssl) { SSL_ASSERT1(ssl); + if (ssl->err) + return 1; + return (SSL_want(ssl) == SSL_NOTHING); } @@ -97,6 +100,9 @@ int SSL_want_read(const SSL *ssl) { SSL_ASSERT1(ssl); + if (ssl->err) + return 0; + return (SSL_want(ssl) == SSL_READING); } @@ -107,6 +113,9 @@ int SSL_want_write(const SSL *ssl) { SSL_ASSERT1(ssl); + if (ssl->err) + return 0; + return (SSL_want(ssl) == SSL_WRITING); } diff --git a/lib/mbedtls_wrapper/platform/ssl_pm.c b/lib/mbedtls_wrapper/platform/ssl_pm.c index 0fef1889..a132971d 100755 --- a/lib/mbedtls_wrapper/platform/ssl_pm.c +++ b/lib/mbedtls_wrapper/platform/ssl_pm.c @@ -306,11 +306,15 @@ int ssl_pm_shutdown(SSL *ssl) ret = mbedtls_ssl_close_notify(&ssl_pm->ssl); if (ret) { SSL_DEBUG(SSL_PLATFORM_ERROR_LEVEL, "mbedtls_ssl_close_notify() return -0x%x", -ret); - ret = -1; + if (ret == MBEDTLS_ERR_NET_CONN_RESET) + ssl->err = SSL_ERROR_SYSCALL; + ret = -1; /* OpenSSL: "Call SSL_get_error with the return value to find the reason */ } else { struct x509_pm *x509_pm = (struct x509_pm *)ssl->session->peer->x509_pm; x509_pm->ex_crt = NULL; + ret = 1; /* OpenSSL: "The shutdown was successfully completed" + ...0 means retry */ } return ret; @@ -330,6 +334,8 @@ int ssl_pm_read(SSL *ssl, void *buffer, int len) ret = mbedtls_ssl_read(&ssl_pm->ssl, buffer, len); if (ret < 0) { SSL_DEBUG(SSL_PLATFORM_ERROR_LEVEL, "mbedtls_ssl_read() return -0x%x", -ret); + if (ret == MBEDTLS_ERR_NET_CONN_RESET) + ssl->err = SSL_ERROR_SYSCALL; ret = -1; } @@ -343,6 +349,8 @@ int ssl_pm_send(SSL *ssl, const void *buffer, int len) ret = mbedtls_ssl_write(&ssl_pm->ssl, buffer, len); if (ret < 0) { + if (ret == MBEDTLS_ERR_NET_CONN_RESET) + ssl->err = SSL_ERROR_SYSCALL; SSL_DEBUG(SSL_PLATFORM_ERROR_LEVEL, "mbedtls_ssl_write() return -0x%x", -ret); ret = -1; } @@ -702,6 +710,9 @@ void SSL_get0_alpn_selected(const SSL *ssl, const unsigned char **data, const char *alp = mbedtls_ssl_get_alpn_protocol(&((struct ssl_pm *)(ssl->ssl_pm))->ssl); *data = (const unsigned char *)alp; - *len = strlen(alp); + if (alp) + *len = strlen(alp); + else + *len = 0; }