diff --git a/include/libwebsockets/lws-sha1-base64.h b/include/libwebsockets/lws-sha1-base64.h index a64c4da80..0438aa9c9 100644 --- a/include/libwebsockets/lws-sha1-base64.h +++ b/include/libwebsockets/lws-sha1-base64.h @@ -90,5 +90,20 @@ lws_b64_decode_string(const char *in, char *out, int out_size); */ LWS_VISIBLE LWS_EXTERN int lws_b64_decode_string_len(const char *in, int in_len, char *out, int out_size); + +struct lws_b64state { + unsigned char quad[4]; + size_t done; + size_t len; + int i; + int c; +}; + +LWS_VISIBLE LWS_EXTERN void +lws_b64_decode_state_init(struct lws_b64state *state); + +LWS_VISIBLE LWS_EXTERN int +lws_b64_decode_stateful(struct lws_b64state *s, const char *in, size_t *in_len, + uint8_t *out, size_t *out_size, int final); ///@} diff --git a/lib/misc/base64-decode.c b/lib/misc/base64-decode.c index 010cc7d71..3262685ab 100644 --- a/lib/misc/base64-decode.c +++ b/lib/misc/base64-decode.c @@ -103,6 +103,79 @@ lws_b64_encode_string_url(const char *in, int in_len, char *out, int out_size) return _lws_b64_encode_string(encode_url, in, in_len, out, out_size); } + +void +lws_b64_decode_state_init(struct lws_b64state *state) +{ + memset(state, 0, sizeof(*state)); +} + +int +lws_b64_decode_stateful(struct lws_b64state *s, const char *in, size_t *in_len, + uint8_t *out, size_t *out_size, int final) +{ + const char *orig_in = in, *end_in = in + *in_len; + uint8_t *orig_out = out, *end_out = out + *out_size; + + while (in < end_in && *in && out + 4 < end_out) { + + for (; s->i < 4 && in < end_in && *in; s->i++) { + uint8_t v; + + v = 0; + s->c = 0; + while (in < end_in && *in && !v) { + s->c = v = *in++; + /* support the url base64 variant too */ + if (v == '-') + s->c = v = '+'; + if (v == '_') + s->c = v = '/'; + v = (v < 43 || v > 122) ? 0 : decode[v - 43]; + if (v) + v = (v == '$') ? 0 : v - 61; + } + if (s->c) { + s->len++; + if (v) + s->quad[s->i] = v - 1; + } else + s->quad[s->i] = 0; + } + + if (s->i != 4 && !final) + continue; + + s->i = 0; + + /* + * "The '==' sequence indicates that the last group contained + * only one byte, and '=' indicates that it contained two + * bytes." (wikipedia) + */ + + if ((in >= end_in || !*in) && s->c == '=') + s->len--; + + if (s->len >= 2) + *out++ = s->quad[0] << 2 | s->quad[1] >> 4; + if (s->len >= 3) + *out++ = s->quad[1] << 4 | s->quad[2] >> 2; + if (s->len >= 4) + *out++ = ((s->quad[2] << 6) & 0xc0) | s->quad[3]; + + s->done += s->len - 1; + s->len = 0; + } + + *out = '\0'; + *in_len = in - orig_in; + *out_size = out - orig_out; + + return 0; +} + + /* * returns length of decoded string in out, or -1 if out was too small * according to out_size @@ -114,65 +187,19 @@ lws_b64_encode_string_url(const char *in, int in_len, char *out, int out_size) static int _lws_b64_decode_string(const char *in, int in_len, char *out, int out_size) { - int len, i, c = 0, done = 0; - unsigned char v, quad[4]; + struct lws_b64state state; + size_t il = in_len, ol = out_size; - while (in_len && *in) { + if (in_len == -1) + il = in_len = strlen(in); - len = 0; - for (i = 0; i < 4 && in_len && *in; i++) { + lws_b64_decode_state_init(&state); + lws_b64_decode_stateful(&state, in, &il, (uint8_t *)out, &ol, 1); - v = 0; - c = 0; - while (in_len && *in && !v) { - c = v = *in++; - in_len--; - /* support the url base64 variant too */ - if (v == '-') - c = v = '+'; - if (v == '_') - c = v = '/'; - v = (v < 43 || v > 122) ? 0 : decode[v - 43]; - if (v) - v = (v == '$') ? 0 : v - 61; - } - if (c) { - len++; - if (v) - quad[i] = v - 1; - } else - quad[i] = 0; - } + if (!il) + return 0; - if (out_size < (done + len + 1)) - /* out buffer is too small */ - return -1; - - /* - * "The '==' sequence indicates that the last group contained - * only one byte, and '=' indicates that it contained two - * bytes." (wikipedia) - */ - - if ((!in_len || !*in) && c == '=') - len--; - - if (len >= 2) - *out++ = quad[0] << 2 | quad[1] >> 4; - if (len >= 3) - *out++ = quad[1] << 4 | quad[2] >> 2; - if (len >= 4) - *out++ = ((quad[2] << 6) & 0xc0) | quad[3]; - - done += len - 1; - } - - if (done + 1 >= out_size) - return -1; - - *out = '\0'; - - return done; + return ol; } LWS_VISIBLE int @@ -188,31 +215,35 @@ lws_b64_decode_string_len(const char *in, int in_len, char *out, int out_size) } #if 0 +static const char * const plaintext[] = { + "any carnal pleasure.", + "any carnal pleasure", + "any carnal pleasur", + "any carnal pleasu", + "any carnal pleas", + "Admin:kloikloi" +}; +static const char * const coded[] = { + "YW55IGNhcm5hbCBwbGVhc3VyZS4=", + "YW55IGNhcm5hbCBwbGVhc3VyZQ==", + "YW55IGNhcm5hbCBwbGVhc3Vy", + "YW55IGNhcm5hbCBwbGVhc3U=", + "YW55IGNhcm5hbCBwbGVhcw==", + "QWRtaW46a2xvaWtsb2k=" +}; + int lws_b64_selftest(void) { char buf[64]; unsigned int n, r = 0; unsigned int test; - /* examples from https://en.wikipedia.org/wiki/Base64 */ - static const char * const plaintext[] = { - "any carnal pleasure.", - "any carnal pleasure", - "any carnal pleasur", - "any carnal pleasu", - "any carnal pleas", - "Admin:kloikloi" - }; - static const char * const coded[] = { - "YW55IGNhcm5hbCBwbGVhc3VyZS4=", - "YW55IGNhcm5hbCBwbGVhc3VyZQ==", - "YW55IGNhcm5hbCBwbGVhc3Vy", - "YW55IGNhcm5hbCBwbGVhc3U=", - "YW55IGNhcm5hbCBwbGVhcw==", - "QWRtaW46a2xvaWtsb2k=" - }; - for (test = 0; test < sizeof plaintext / sizeof(plaintext[0]); test++) { + lwsl_notice("%s\n", __func__); + + /* examples from https://en.wikipedia.org/wiki/Base64 */ + + for (test = 0; test < (int)LWS_ARRAY_SIZE(plaintext); test++) { buf[sizeof(buf) - 1] = '\0'; n = lws_b64_encode_string(plaintext[test], @@ -226,15 +257,20 @@ lws_b64_selftest(void) buf[sizeof(buf) - 1] = '\0'; n = lws_b64_decode_string(coded[test], buf, sizeof buf); if (n != strlen(plaintext[test]) || - strcmp(buf, plaintext[test])) { + strcmp(buf, plaintext[test])) { lwsl_err("Failed lws_b64 decode selftest " - "%d result '%s' / '%s', %d / %d\n", - test, buf, plaintext[test], n, strlen(plaintext[test])); + "%d result '%s' / '%s', %d / %zu\n", + test, buf, plaintext[test], n, + strlen(plaintext[test])); + lwsl_hexdump_err(buf, n); r = -1; } } - lwsl_notice("Base 64 selftests passed\n"); + if (!r) + lwsl_notice("Base 64 selftests passed\n"); + else + lwsl_notice("Base64 selftests failed\n"); return r; }