diff --git a/include/libwebsockets/lws-jose.h b/include/libwebsockets/lws-jose.h index c780c0e2b..54b6690c5 100644 --- a/include/libwebsockets/lws-jose.h +++ b/include/libwebsockets/lws-jose.h @@ -120,6 +120,7 @@ struct lws_jose { struct lws_jws_recpient recipient[LWS_JWS_MAX_RECIPIENTS]; char typ[32]; + char edone[LWS_COUNT_JOSE_HDR_ELEMENTS]; /* information from the protected header part */ const struct lws_jose_jwe_alg *alg; diff --git a/lib/jose/jws/jose.c b/lib/jose/jws/jose.c index 4007865bf..ca86f88a2 100644 --- a/lib/jose/jws/jose.c +++ b/lib/jose/jws/jose.c @@ -302,7 +302,6 @@ lws_jws_jose_cb(struct lejp_ctx *ctx, char reason) if (!args->is_jwe) return -1; /* Ephemeral key... this JSON subsection is actually a JWK */ - lwsl_err("LJJHI_EPK\n"); break; case LJJHI_APU: /* Additional arg for JWE ECDH */ @@ -361,17 +360,22 @@ append_string: *args->temp_len -= ctx->npos; args->jose->e[ctx->path_match - 1].len += ctx->npos; - if (reason == LEJPCB_VAL_STR_END) { + if (reason == LEJPCB_VAL_STR_END && + (int)args->jose->e[ctx->path_match - 1].len && + !args->jose->edone[ctx->path_match - 1]) { n = lws_b64_decode_string_len( (const char *)args->jose->e[ctx->path_match - 1].buf, (int)args->jose->e[ctx->path_match - 1].len, (char *)args->jose->e[ctx->path_match - 1].buf, (int)args->jose->e[ctx->path_match - 1].len + 1); if (n < 0) { - lwsl_err("%s: b64 decode failed\n", __func__); + lwsl_err("%s: b64 decode failed len %d\n", __func__, + (int)args->jose->e[ctx->path_match - 1].len); + return -1; } + args->jose->edone[ctx->path_match - 1] = 1; args->temp -= (int)args->jose->e[ctx->path_match - 1].len - n - 1; *args->temp_len += (int)args->jose->e[ctx->path_match - 1].len - n - 1; diff --git a/lib/jose/jws/jws.c b/lib/jose/jws/jws.c index 6364c6be7..d8d7e154e 100644 --- a/lib/jose/jws/jws.c +++ b/lib/jose/jws/jws.c @@ -339,20 +339,24 @@ lws_jws_compact_decode(const char *in, int len, struct lws_jws_map *map, return -1; while (m < blocks) { - n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m], - out, *out_len); - if (n < 0) { - lwsl_err("%s: b64 decode failed\n", __func__); - return -1; - } - /* replace the map entry with the decoded content */ - if (n) - map->buf[m] = out; - else - map->buf[m] = NULL; - map->len[m++] = (unsigned int)n; - out += n; - *out_len -= n; + if ((int)map_b64->len[m]) { + n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m], + out, *out_len); + if (n < 0) { + lwsl_err("%s: b64 decode failed len %d\n", + __func__, (int)map_b64->len[m]); + return -1; + } + /* replace the map entry with the decoded content */ + if (n) + map->buf[m] = out; + else + map->buf[m] = NULL; + map->len[m++] = (unsigned int)n; + out += n; + *out_len -= n; + } else + m++; if (*out_len < 1) return -1; @@ -368,15 +372,20 @@ lws_jws_compact_decode_map(struct lws_jws_map *map_b64, struct lws_jws_map *map, int n, m = 0; for (n = 0; n < LWS_JWS_MAX_COMPACT_BLOCKS; n++) { - n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m], - out, *out_len); - if (n < 0) { - lwsl_err("%s: b64 decode failed\n", __func__); - return -1; - } - /* replace the map entry with the decoded content */ - map->buf[m] = out; - map->len[m++] = (unsigned int)n; + if ((int)map_b64->len[m]) { + n = lws_b64_decode_string_len(map_b64->buf[m], (int)map_b64->len[m], + out, *out_len); + if (n < 0) { + lwsl_err("%s: b64 decode failed len %d\n", + __func__, (int)map_b64->len[m]); + + return -1; + } + /* replace the map entry with the decoded content */ + map->buf[m] = out; + map->len[m++] = (unsigned int)n; + } else + m++; out += n; *out_len -= n; diff --git a/lib/misc/base64-decode.c b/lib/misc/base64-decode.c index 0776509f7..458a6600d 100644 --- a/lib/misc/base64-decode.c +++ b/lib/misc/base64-decode.c @@ -112,6 +112,7 @@ lws_b64_decode_stateful(struct lws_b64state *s, const char *in, size_t *in_len, { const char *orig_in = in, *end_in = in + *in_len; uint8_t *orig_out = out, *end_out = out + *out_size; + int equals = 0; while (in < end_in && *in && out + 4 < end_out) { @@ -122,12 +123,39 @@ lws_b64_decode_stateful(struct lws_b64state *s, const char *in, size_t *in_len, s->c = 0; while (in < end_in && *in && !v) { s->c = v = (unsigned char)*in++; + + if (v == '\x0a') { + v = 0; + continue; + } + + if (v == '=') { + equals++; + v = 0; + continue; + } + + /* Sanity check this is part of the charset */ + + if ((v < '0' || v > '9') && + (v < 'A' || v > 'Z') && + (v < 'a' || v > 'z') && + v != '-' && v != '+' && v != '_' && v != '/') { + lwsl_err("%s: bad base64 0x%02X '%c' @+%d\n", __func__, v, v, lws_ptr_diff(in, orig_in)); + return -1; + } + + if (equals) { + lwsl_err("%s: non = after =\n", __func__); + return -1; + } + /* support the url base64 variant too */ if (v == '-') s->c = v = '+'; if (v == '_') s->c = v = '/'; - v = (uint8_t)((v < 43 || v > 122) ? 0 : decode[v - 43]); + v = (uint8_t)decode[v - 43]; if (v) v = (uint8_t)((v == '$') ? 0 : v - 61); } @@ -150,14 +178,11 @@ lws_b64_decode_stateful(struct lws_b64state *s, const char *in, size_t *in_len, * bytes." (wikipedia) */ - if ((in >= end_in || !*in) && s->c == '=') - s->len--; - - if (s->len >= 2) + if (s->len >= 2 || equals > 1) *out++ = (uint8_t)(s->quad[0] << 2 | s->quad[1] >> 4); - if (s->len >= 3) + if (s->len >= 3 || equals) *out++ = (uint8_t)(s->quad[1] << 4 | s->quad[2] >> 2); - if (s->len >= 4) + if (s->len >= 4 && !equals) *out++ = (uint8_t)(((s->quad[2] << 6) & 0xc0) | s->quad[3]); s->done += s->len - 1; @@ -192,11 +217,15 @@ _lws_b64_decode_string(const char *in, int in_len, char *out, size_t out_size) il = strlen(in); lws_b64_decode_state_init(&state); - lws_b64_decode_stateful(&state, in, &il, (uint8_t *)out, &ol, 1); - - if (!il) + if (lws_b64_decode_stateful(&state, in, &il, (uint8_t *)out, &ol, 1) < 0) + /* pass on the failure */ return 0; + if ((int)il != in_len) { + lwsl_err("%s: base64 must end at end of input\n", __func__); + return 0; + } + return ol; } @@ -209,7 +238,9 @@ lws_b64_decode_string(const char *in, char *out, int out_size) int lws_b64_decode_string_len(const char *in, int in_len, char *out, int out_size) { - return (int)_lws_b64_decode_string(in, in_len, out, (unsigned int)out_size); + size_t s = _lws_b64_decode_string(in, in_len, out, (unsigned int)out_size); + + return !s ? -1 : (int)s; } #if 0 diff --git a/lib/tls/mbedtls/mbedtls-server.c b/lib/tls/mbedtls/mbedtls-server.c index ca703c5a2..36cb9fd00 100644 --- a/lib/tls/mbedtls/mbedtls-server.c +++ b/lib/tls/mbedtls/mbedtls-server.c @@ -148,7 +148,7 @@ lws_tls_server_certs_load(struct lws_vhost *vhost, struct lws *wsi, } if (lws_tls_alloc_pem_to_der_file(vhost->context, cert, mem_cert, mem_cert_len, &p, &flen)) { - lwsl_err("couldn't find cert file %s\n", cert); + lwsl_err("couldn't load cert file %s\n", cert); return 1; } diff --git a/lib/tls/tls.c b/lib/tls/tls.c index 29a7a5d5d..97481ac25 100644 --- a/lib/tls/tls.c +++ b/lib/tls/tls.c @@ -361,7 +361,7 @@ lws_tls_alloc_pem_to_der_file(struct lws_context *context, const char *filename, if (!filename) { /* we don't know if it's in const memory... alloc the output */ - pem = lws_malloc(((size_t)inlen * 3) / 4, "alloc_der"); + pem = lws_malloc(((size_t)(inlen + 3) * 3) / 4, "alloc_der"); if (!pem) { lwsl_err("a\n"); return 1; @@ -408,8 +408,15 @@ lws_tls_alloc_pem_to_der_file(struct lws_context *context, const char *filename, n = lws_ptr_diff(q, p); if (n == -1) /* coverity */ goto bail; - *amount = (unsigned int)lws_b64_decode_string_len((char *)p, n, - (char *)pem, (int)(long long)len); + + n = lws_b64_decode_string_len((char *)p, n, + (char *)pem, (int)(long long)len); + if (n < 0) { + lwsl_err("%s: base64 pem decode failed\n", __func__); + goto bail; + } + + *amount = (unsigned int)n; *buf = (uint8_t *)pem; return 0;