From aa45de9e2ab1fe6dba38dbc979a482bbf6cb2796 Mon Sep 17 00:00:00 2001 From: Andy Green Date: Sat, 2 Jan 2021 10:49:43 +0000 Subject: [PATCH] ss: enforce only valid state transitions The various stream transitions for direct ss, SSPC, smd, and different protocols are all handled in different code, let's stop hoping for the best and add a state transition validation function that is used everywhere we pass a state change to a user callback, and knows what is valid for the user state() callback to see next, given the last state it was shown. Let's assert if lws manages to violate that so we can find where the problem is and provide a stricter guarantee about what user state handler will see, no matter if ss or sspc or other cases. To facilitate that, move the states to start from 1, where 0 indicates the state unset. --- include/libwebsockets/lws-secure-streams.h | 6 +- lib/core-net/adopt.c | 7 +- lib/core-net/client/connect.c | 9 +- lib/core-net/dummy-callback.c | 3 +- .../private-lib-secure-streams.h | 10 ++ lib/secure-streams/protocols/ss-h1.c | 8 +- lib/secure-streams/secure-streams-client.c | 23 ++- lib/secure-streams/secure-streams-serialize.c | 13 +- lib/secure-streams/secure-streams.c | 155 +++++++++++++++++- 9 files changed, 213 insertions(+), 21 deletions(-) diff --git a/include/libwebsockets/lws-secure-streams.h b/include/libwebsockets/lws-secure-streams.h index a0bf5977c..bc1ba88d1 100644 --- a/include/libwebsockets/lws-secure-streams.h +++ b/include/libwebsockets/lws-secure-streams.h @@ -187,9 +187,13 @@ typedef uint32_t lws_ss_tx_ordinal_t; /* * connection state events + * + * If you add states, take care about the state names and state transition + * validity enforcement tables too */ typedef enum { - LWSSSCS_CREATING, + /* zero means unset */ + LWSSSCS_CREATING = 1, LWSSSCS_DISCONNECTED, LWSSSCS_UNREACHABLE, /* oridinal arg = 1 = caused by dns * server reachability failure */ diff --git a/lib/core-net/adopt.c b/lib/core-net/adopt.c index 2427d30c6..ffeef5141 100644 --- a/lib/core-net/adopt.c +++ b/lib/core-net/adopt.c @@ -301,8 +301,11 @@ lws_adopt_ss_server_accept(struct lws *new_wsi) goto fail; if (lws_ss_event_helper(h, LWSSSCS_CONNECTING)) goto fail; - if (lws_ss_event_helper(h, LWSSSCS_CONNECTED)) - goto fail; + + /* defer CONNECTED until we see if he is upgrading */ + +// if (lws_ss_event_helper(h, LWSSSCS_CONNECTED)) +// goto fail; // lwsl_notice("%s: accepted ss complete, pcol %s\n", __func__, // new_wsi->a.protocol->name); diff --git a/lib/core-net/client/connect.c b/lib/core-net/client/connect.c index 8d12d647a..dfe3f88bf 100644 --- a/lib/core-net/client/connect.c +++ b/lib/core-net/client/connect.c @@ -334,7 +334,14 @@ lws_client_connect_via_info(const struct lws_client_connect_info *i) __lws_lc_tag(&i->context->lcg[ #if defined(LWS_WITH_SECURE_STREAMS_PROXY_API) i->ssl_connection & LCCSCF_SECSTREAM_PROXY_LINK ? LWSLCG_WSI_SSP_CLIENT : - (i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : LWSLCG_WSI_CLIENT)], +#if defined(LWS_WITH_SERVER) + (i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : +#endif + LWSLCG_WSI_CLIENT +#if defined(LWS_WITH_SERVER) + ) +#endif + ], #else LWSLCG_WSI_CLIENT], #endif diff --git a/lib/core-net/dummy-callback.c b/lib/core-net/dummy-callback.c index 4a5afff7a..2152e8a18 100644 --- a/lib/core-net/dummy-callback.c +++ b/lib/core-net/dummy-callback.c @@ -697,9 +697,10 @@ lws_callback_http_dummy(struct lws *wsi, enum lws_callback_reasons reason, if (wsi->mux_substream && !wsi->cgi_stdout_zero_length) lws_write(wsi, (unsigned char *)buf + LWS_PRE, 0, LWS_WRITE_HTTP_FINAL); - +#if defined(LWS_WITH_SERVER) if (lws_http_transaction_completed(wsi)) return -1; +#endif return 0; case LWS_CALLBACK_CGI_STDIN_DATA: /* POST body for stdin */ diff --git a/lib/secure-streams/private-lib-secure-streams.h b/lib/secure-streams/private-lib-secure-streams.h index 4823c62e2..515206542 100644 --- a/lib/secure-streams/private-lib-secure-streams.h +++ b/lib/secure-streams/private-lib-secure-streams.h @@ -156,6 +156,7 @@ typedef struct lws_ss_handle { uint8_t tsi; /**< service thread idx, usually 0 */ uint8_t subseq; /**< emulate SOM tracking */ uint8_t txn_ok; /**< 1 = transaction was OK */ + uint8_t prev_ss_state; uint8_t txn_resp_set:1; /**< user code set one */ uint8_t txn_resp_pending:1; /**< we have yet to send */ @@ -300,6 +301,8 @@ typedef struct lws_sspc_handle { uint8_t rideshare_ofs[4]; uint8_t rsidx; + uint8_t prev_ss_state; + uint8_t conn_req_state:2; uint8_t destroying:1; uint8_t non_wsi:1; @@ -457,6 +460,13 @@ struct lws_vhost * lws_ss_policy_ref_trust_store(struct lws_context *context, const lws_ss_policy_t *pol, char doref); +lws_ss_state_return_t +lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs, + lws_ss_tx_ordinal_t flags); + +int +lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs); + void lws_proxy_clean_conn_ss(struct lws *wsi); diff --git a/lib/secure-streams/protocols/ss-h1.c b/lib/secure-streams/protocols/ss-h1.c index be8b98bb0..d4ad65664 100644 --- a/lib/secure-streams/protocols/ss-h1.c +++ b/lib/secure-streams/protocols/ss-h1.c @@ -514,9 +514,11 @@ secstream_h1(struct lws *wsi, enum lws_callback_reasons reason, void *user, h->seqstate = SSSEQ_CONNECTED; lws_sul_cancel(&h->sul); - r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); - if (r != LWSSSSRET_OK) - return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h); + if (h->prev_ss_state != LWSSSCS_CONNECTED) { + r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); + if (r != LWSSSSRET_OK) + return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h); + } /* * Since it's an http transaction we initiated... this is diff --git a/lib/secure-streams/secure-streams-client.c b/lib/secure-streams/secure-streams-client.c index 3c83b913e..eed5b2342 100644 --- a/lib/secure-streams/secure-streams-client.c +++ b/lib/secure-streams/secure-streams-client.c @@ -19,6 +19,22 @@ */ #include +lws_ss_state_return_t +lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs, + lws_ss_tx_ordinal_t flags) +{ + if (!h) + return LWSSSSRET_OK; + + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + + if (!h->ssi.state) + return LWSSSSRET_OK; + + return h->ssi.state((void *)((uint8_t *)&h[1]), NULL, cs, flags); +} + static void lws_sspc_sul_retry_cb(lws_sorted_usec_list_t *sul) { @@ -533,7 +549,6 @@ void lws_sspc_destroy(lws_sspc_handle_t **ph) { lws_sspc_handle_t *h; - void *m; lwsl_debug("%s\n", __func__); @@ -541,7 +556,6 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) return; h = *ph; - m = (void *)((uint8_t *)&h[1]); if (h->destroying) return; @@ -549,7 +563,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) h->destroying = 1; if (h->ss_dangling_connected && h->ssi.state) { - h->ssi.state(m, NULL, LWSSSCS_DISCONNECTED, 0); + lws_sspc_event_helper(h, LWSSSCS_DISCONNECTED, 0); h->ss_dangling_connected = 0; } @@ -578,8 +592,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) lws_sspc_rxmetadata_destroy(h); - if (h->ssi.state) - h->ssi.state(m, NULL, LWSSSCS_DESTROYING, 0); + lws_sspc_event_helper(h, LWSSSCS_DESTROYING, 0); *ph = NULL; __lws_lc_untag(&h->lc); diff --git a/lib/secure-streams/secure-streams-serialize.c b/lib/secure-streams/secure-streams-serialize.c index ef9b6f9c8..833ae16a5 100644 --- a/lib/secure-streams/secure-streams-serialize.c +++ b/lib/secure-streams/secure-streams-serialize.c @@ -1240,6 +1240,11 @@ payload_ff: h->creating_cb_done = 1; + if (lws_ss_check_next_state(&h->prev_ss_state, LWSSSCS_CREATING)) + return LWSSSSRET_DESTROY_ME; + + h->prev_ss_state = (uint8_t)LWSSSCS_CREATING; + if (ssi->state) { n = ssi->state(client_pss_to_userdata(pss), NULL, LWSSSCS_CREATING, 0); @@ -1391,8 +1396,14 @@ payload_ff: if (cs == LWSSSCS_DISCONNECTED) h->ss_dangling_connected = 0; + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + + if (cs < LWSSSCS_USER_BASE) + h->prev_ss_state = (uint8_t)cs; + n = ssi->state(client_pss_to_userdata(pss), - NULL, (lws_ss_constate_t)par->ctr, par->flags); + NULL, cs, par->flags); switch (n) { case LWSSSSRET_OK: break; diff --git a/lib/secure-streams/secure-streams.c b/lib/secure-streams/secure-streams.c index b39007164..bec3a8bab 100644 --- a/lib/secure-streams/secure-streams.c +++ b/lib/secure-streams/secure-streams.c @@ -50,6 +50,7 @@ static const struct ss_pcols *ss_pcols[] = { }; static const char *state_names[] = { + "(unset)", "LWSSSCS_CREATING", "LWSSSCS_DISCONNECTED", "LWSSSCS_UNREACHABLE", @@ -68,6 +69,144 @@ static const char *state_names[] = { "LWSSSCS_SERVER_UPGRADE", }; +/* + * For each "current state", set bit offsets for valid "next states". + * + * Since there are complicated ways to arrive at state transitions like proxying + * and asynchronous destruction etc, so we monitor the state transitions we are + * giving the ss user code to ensure we never deliver illegal state transitions + * (because we will assert if we have bugs that do it) + */ + +static const uint32_t ss_state_txn_validity[] = { + + /* if we was last in this state... we can legally go to these states */ + + [0] = (1 << LWSSSCS_CREATING) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CREATING] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_SERVER_UPGRADE) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_DISCONNECTED] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_UNREACHABLE] = (1 << LWSSSCS_ALL_RETRIES_FAILED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_CONNECTING) | + /* win conn failure > retry > succ */ + (1 << LWSSSCS_CONNECTED) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_AUTH_FAILED] = (1 << LWSSSCS_ALL_RETRIES_FAILED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CONNECTED] = (1 << LWSSSCS_SERVER_UPGRADE) | + (1 << LWSSSCS_AUTH_FAILED) | + (1 << LWSSSCS_QOS_ACK_REMOTE) | + (1 << LWSSSCS_QOS_NACK_REMOTE) | + (1 << LWSSSCS_QOS_ACK_LOCAL) | + (1 << LWSSSCS_QOS_NACK_LOCAL) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CONNECTING] = (1 << LWSSSCS_UNREACHABLE) | + (1 << LWSSSCS_AUTH_FAILED) | + (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_CONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_DESTROYING] = 0, + + [LWSSSCS_POLL] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_ALL_RETRIES_FAILED] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_ACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_NACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_ACK_LOCAL] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_NACK_LOCAL] = (1 << LWSSSCS_DESTROYING) | + (1 << LWSSSCS_TIMEOUT), + + [LWSSSCS_TIMEOUT] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_SERVER_TXN] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_SERVER_UPGRADE] = (1 << LWSSSCS_SERVER_TXN) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_DESTROYING), +}; + +int +lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs) +{ + if (cs >= LWSSSCS_USER_BASE) + /* + * we can't judge user states, leave the old state and + * just wave them through + */ + return 0; + + if (cs >= LWS_ARRAY_SIZE(ss_state_txn_validity)) { + /* we don't recognize this state as usable */ + lwsl_err("%s: bad new state %u\n", __func__, cs); + assert(0); + return 1; + } + + if (*prevstate >= LWS_ARRAY_SIZE(ss_state_txn_validity)) { + /* existing state is broken */ + lwsl_err("%s: bad existing state %u\n", __func__, + (unsigned int)*prevstate); + assert(0); + return 1; + } + + if (ss_state_txn_validity[*prevstate] & (1u << cs)) { + /* this is explicitly allowed, update old state to new */ + *prevstate = (uint8_t)cs; + + return 0; + } + + lwsl_err("%s: transition from %s -> %s is illegal\n", __func__, + lws_ss_state_name((int)*prevstate), + lws_ss_state_name((int)cs)); + + assert(0); + + return 1; +} + const char * lws_ss_state_name(int state) { @@ -88,6 +227,9 @@ lws_ss_event_helper(lws_ss_handle_t *h, lws_ss_constate_t cs) if (!h) return LWSSSSRET_OK; + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + if (cs == LWSSSCS_CONNECTED) h->ss_dangling_connected = 1; if (cs == LWSSSCS_DISCONNECTED) @@ -707,7 +849,7 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, */ if (!(ssi->flags & LWSSSINFLAGS_PROXIED) && pol == &pol_smd) { - lws_ss_state_return_t r; + /* * So he has asked to be wired up to SMD over a SS link. * Register him as an smd participant in his own right. @@ -722,12 +864,6 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, if (!h->u.smd.smd_peer) goto late_bail; lwsl_info("%s: registered SS SMD\n", __func__); - r = lws_ss_event_helper(h, LWSSSCS_CONNECTING); - if (r) - return r; - r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); - if (r) - return r; } #endif @@ -798,6 +934,11 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, */ vho->ss_handle = h; + r = lws_ss_event_helper(h, LWSSSCS_CREATING); + lwsl_info("%s: CREATING returned status %d\n", __func__, (int)r); + if (r == LWSSSSRET_DESTROY_ME) + goto late_bail; + lwsl_notice("%s: created server %s\n", __func__, h->policy->streamtype);