diff --git a/include/libwebsockets/lws-secure-streams.h b/include/libwebsockets/lws-secure-streams.h index a0bf5977c..bc1ba88d1 100644 --- a/include/libwebsockets/lws-secure-streams.h +++ b/include/libwebsockets/lws-secure-streams.h @@ -187,9 +187,13 @@ typedef uint32_t lws_ss_tx_ordinal_t; /* * connection state events + * + * If you add states, take care about the state names and state transition + * validity enforcement tables too */ typedef enum { - LWSSSCS_CREATING, + /* zero means unset */ + LWSSSCS_CREATING = 1, LWSSSCS_DISCONNECTED, LWSSSCS_UNREACHABLE, /* oridinal arg = 1 = caused by dns * server reachability failure */ diff --git a/lib/core-net/adopt.c b/lib/core-net/adopt.c index 2427d30c6..ffeef5141 100644 --- a/lib/core-net/adopt.c +++ b/lib/core-net/adopt.c @@ -301,8 +301,11 @@ lws_adopt_ss_server_accept(struct lws *new_wsi) goto fail; if (lws_ss_event_helper(h, LWSSSCS_CONNECTING)) goto fail; - if (lws_ss_event_helper(h, LWSSSCS_CONNECTED)) - goto fail; + + /* defer CONNECTED until we see if he is upgrading */ + +// if (lws_ss_event_helper(h, LWSSSCS_CONNECTED)) +// goto fail; // lwsl_notice("%s: accepted ss complete, pcol %s\n", __func__, // new_wsi->a.protocol->name); diff --git a/lib/core-net/client/connect.c b/lib/core-net/client/connect.c index 8d12d647a..dfe3f88bf 100644 --- a/lib/core-net/client/connect.c +++ b/lib/core-net/client/connect.c @@ -334,7 +334,14 @@ lws_client_connect_via_info(const struct lws_client_connect_info *i) __lws_lc_tag(&i->context->lcg[ #if defined(LWS_WITH_SECURE_STREAMS_PROXY_API) i->ssl_connection & LCCSCF_SECSTREAM_PROXY_LINK ? LWSLCG_WSI_SSP_CLIENT : - (i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : LWSLCG_WSI_CLIENT)], +#if defined(LWS_WITH_SERVER) + (i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : +#endif + LWSLCG_WSI_CLIENT +#if defined(LWS_WITH_SERVER) + ) +#endif + ], #else LWSLCG_WSI_CLIENT], #endif diff --git a/lib/core-net/dummy-callback.c b/lib/core-net/dummy-callback.c index 4a5afff7a..2152e8a18 100644 --- a/lib/core-net/dummy-callback.c +++ b/lib/core-net/dummy-callback.c @@ -697,9 +697,10 @@ lws_callback_http_dummy(struct lws *wsi, enum lws_callback_reasons reason, if (wsi->mux_substream && !wsi->cgi_stdout_zero_length) lws_write(wsi, (unsigned char *)buf + LWS_PRE, 0, LWS_WRITE_HTTP_FINAL); - +#if defined(LWS_WITH_SERVER) if (lws_http_transaction_completed(wsi)) return -1; +#endif return 0; case LWS_CALLBACK_CGI_STDIN_DATA: /* POST body for stdin */ diff --git a/lib/secure-streams/private-lib-secure-streams.h b/lib/secure-streams/private-lib-secure-streams.h index 4823c62e2..515206542 100644 --- a/lib/secure-streams/private-lib-secure-streams.h +++ b/lib/secure-streams/private-lib-secure-streams.h @@ -156,6 +156,7 @@ typedef struct lws_ss_handle { uint8_t tsi; /**< service thread idx, usually 0 */ uint8_t subseq; /**< emulate SOM tracking */ uint8_t txn_ok; /**< 1 = transaction was OK */ + uint8_t prev_ss_state; uint8_t txn_resp_set:1; /**< user code set one */ uint8_t txn_resp_pending:1; /**< we have yet to send */ @@ -300,6 +301,8 @@ typedef struct lws_sspc_handle { uint8_t rideshare_ofs[4]; uint8_t rsidx; + uint8_t prev_ss_state; + uint8_t conn_req_state:2; uint8_t destroying:1; uint8_t non_wsi:1; @@ -457,6 +460,13 @@ struct lws_vhost * lws_ss_policy_ref_trust_store(struct lws_context *context, const lws_ss_policy_t *pol, char doref); +lws_ss_state_return_t +lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs, + lws_ss_tx_ordinal_t flags); + +int +lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs); + void lws_proxy_clean_conn_ss(struct lws *wsi); diff --git a/lib/secure-streams/protocols/ss-h1.c b/lib/secure-streams/protocols/ss-h1.c index be8b98bb0..d4ad65664 100644 --- a/lib/secure-streams/protocols/ss-h1.c +++ b/lib/secure-streams/protocols/ss-h1.c @@ -514,9 +514,11 @@ secstream_h1(struct lws *wsi, enum lws_callback_reasons reason, void *user, h->seqstate = SSSEQ_CONNECTED; lws_sul_cancel(&h->sul); - r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); - if (r != LWSSSSRET_OK) - return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h); + if (h->prev_ss_state != LWSSSCS_CONNECTED) { + r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); + if (r != LWSSSSRET_OK) + return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h); + } /* * Since it's an http transaction we initiated... this is diff --git a/lib/secure-streams/secure-streams-client.c b/lib/secure-streams/secure-streams-client.c index 3c83b913e..eed5b2342 100644 --- a/lib/secure-streams/secure-streams-client.c +++ b/lib/secure-streams/secure-streams-client.c @@ -19,6 +19,22 @@ */ #include +lws_ss_state_return_t +lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs, + lws_ss_tx_ordinal_t flags) +{ + if (!h) + return LWSSSSRET_OK; + + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + + if (!h->ssi.state) + return LWSSSSRET_OK; + + return h->ssi.state((void *)((uint8_t *)&h[1]), NULL, cs, flags); +} + static void lws_sspc_sul_retry_cb(lws_sorted_usec_list_t *sul) { @@ -533,7 +549,6 @@ void lws_sspc_destroy(lws_sspc_handle_t **ph) { lws_sspc_handle_t *h; - void *m; lwsl_debug("%s\n", __func__); @@ -541,7 +556,6 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) return; h = *ph; - m = (void *)((uint8_t *)&h[1]); if (h->destroying) return; @@ -549,7 +563,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) h->destroying = 1; if (h->ss_dangling_connected && h->ssi.state) { - h->ssi.state(m, NULL, LWSSSCS_DISCONNECTED, 0); + lws_sspc_event_helper(h, LWSSSCS_DISCONNECTED, 0); h->ss_dangling_connected = 0; } @@ -578,8 +592,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph) lws_sspc_rxmetadata_destroy(h); - if (h->ssi.state) - h->ssi.state(m, NULL, LWSSSCS_DESTROYING, 0); + lws_sspc_event_helper(h, LWSSSCS_DESTROYING, 0); *ph = NULL; __lws_lc_untag(&h->lc); diff --git a/lib/secure-streams/secure-streams-serialize.c b/lib/secure-streams/secure-streams-serialize.c index ef9b6f9c8..833ae16a5 100644 --- a/lib/secure-streams/secure-streams-serialize.c +++ b/lib/secure-streams/secure-streams-serialize.c @@ -1240,6 +1240,11 @@ payload_ff: h->creating_cb_done = 1; + if (lws_ss_check_next_state(&h->prev_ss_state, LWSSSCS_CREATING)) + return LWSSSSRET_DESTROY_ME; + + h->prev_ss_state = (uint8_t)LWSSSCS_CREATING; + if (ssi->state) { n = ssi->state(client_pss_to_userdata(pss), NULL, LWSSSCS_CREATING, 0); @@ -1391,8 +1396,14 @@ payload_ff: if (cs == LWSSSCS_DISCONNECTED) h->ss_dangling_connected = 0; + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + + if (cs < LWSSSCS_USER_BASE) + h->prev_ss_state = (uint8_t)cs; + n = ssi->state(client_pss_to_userdata(pss), - NULL, (lws_ss_constate_t)par->ctr, par->flags); + NULL, cs, par->flags); switch (n) { case LWSSSSRET_OK: break; diff --git a/lib/secure-streams/secure-streams.c b/lib/secure-streams/secure-streams.c index b39007164..bec3a8bab 100644 --- a/lib/secure-streams/secure-streams.c +++ b/lib/secure-streams/secure-streams.c @@ -50,6 +50,7 @@ static const struct ss_pcols *ss_pcols[] = { }; static const char *state_names[] = { + "(unset)", "LWSSSCS_CREATING", "LWSSSCS_DISCONNECTED", "LWSSSCS_UNREACHABLE", @@ -68,6 +69,144 @@ static const char *state_names[] = { "LWSSSCS_SERVER_UPGRADE", }; +/* + * For each "current state", set bit offsets for valid "next states". + * + * Since there are complicated ways to arrive at state transitions like proxying + * and asynchronous destruction etc, so we monitor the state transitions we are + * giving the ss user code to ensure we never deliver illegal state transitions + * (because we will assert if we have bugs that do it) + */ + +static const uint32_t ss_state_txn_validity[] = { + + /* if we was last in this state... we can legally go to these states */ + + [0] = (1 << LWSSSCS_CREATING) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CREATING] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_SERVER_UPGRADE) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_DISCONNECTED] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_UNREACHABLE] = (1 << LWSSSCS_ALL_RETRIES_FAILED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_CONNECTING) | + /* win conn failure > retry > succ */ + (1 << LWSSSCS_CONNECTED) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_AUTH_FAILED] = (1 << LWSSSCS_ALL_RETRIES_FAILED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CONNECTED] = (1 << LWSSSCS_SERVER_UPGRADE) | + (1 << LWSSSCS_AUTH_FAILED) | + (1 << LWSSSCS_QOS_ACK_REMOTE) | + (1 << LWSSSCS_QOS_NACK_REMOTE) | + (1 << LWSSSCS_QOS_ACK_LOCAL) | + (1 << LWSSSCS_QOS_NACK_LOCAL) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_CONNECTING] = (1 << LWSSSCS_UNREACHABLE) | + (1 << LWSSSCS_AUTH_FAILED) | + (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_CONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_DESTROYING] = 0, + + [LWSSSCS_POLL] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_ALL_RETRIES_FAILED] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_ACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_NACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_ACK_LOCAL] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_QOS_NACK_LOCAL] = (1 << LWSSSCS_DESTROYING) | + (1 << LWSSSCS_TIMEOUT), + + [LWSSSCS_TIMEOUT] = (1 << LWSSSCS_CONNECTING) | + (1 << LWSSSCS_POLL) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_SERVER_TXN] = (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DESTROYING), + + [LWSSSCS_SERVER_UPGRADE] = (1 << LWSSSCS_SERVER_TXN) | + (1 << LWSSSCS_TIMEOUT) | + (1 << LWSSSCS_DISCONNECTED) | + (1 << LWSSSCS_DESTROYING), +}; + +int +lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs) +{ + if (cs >= LWSSSCS_USER_BASE) + /* + * we can't judge user states, leave the old state and + * just wave them through + */ + return 0; + + if (cs >= LWS_ARRAY_SIZE(ss_state_txn_validity)) { + /* we don't recognize this state as usable */ + lwsl_err("%s: bad new state %u\n", __func__, cs); + assert(0); + return 1; + } + + if (*prevstate >= LWS_ARRAY_SIZE(ss_state_txn_validity)) { + /* existing state is broken */ + lwsl_err("%s: bad existing state %u\n", __func__, + (unsigned int)*prevstate); + assert(0); + return 1; + } + + if (ss_state_txn_validity[*prevstate] & (1u << cs)) { + /* this is explicitly allowed, update old state to new */ + *prevstate = (uint8_t)cs; + + return 0; + } + + lwsl_err("%s: transition from %s -> %s is illegal\n", __func__, + lws_ss_state_name((int)*prevstate), + lws_ss_state_name((int)cs)); + + assert(0); + + return 1; +} + const char * lws_ss_state_name(int state) { @@ -88,6 +227,9 @@ lws_ss_event_helper(lws_ss_handle_t *h, lws_ss_constate_t cs) if (!h) return LWSSSSRET_OK; + if (lws_ss_check_next_state(&h->prev_ss_state, cs)) + return LWSSSSRET_DESTROY_ME; + if (cs == LWSSSCS_CONNECTED) h->ss_dangling_connected = 1; if (cs == LWSSSCS_DISCONNECTED) @@ -707,7 +849,7 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, */ if (!(ssi->flags & LWSSSINFLAGS_PROXIED) && pol == &pol_smd) { - lws_ss_state_return_t r; + /* * So he has asked to be wired up to SMD over a SS link. * Register him as an smd participant in his own right. @@ -722,12 +864,6 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, if (!h->u.smd.smd_peer) goto late_bail; lwsl_info("%s: registered SS SMD\n", __func__); - r = lws_ss_event_helper(h, LWSSSCS_CONNECTING); - if (r) - return r; - r = lws_ss_event_helper(h, LWSSSCS_CONNECTED); - if (r) - return r; } #endif @@ -798,6 +934,11 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi, */ vho->ss_handle = h; + r = lws_ss_event_helper(h, LWSSSCS_CREATING); + lwsl_info("%s: CREATING returned status %d\n", __func__, (int)r); + if (r == LWSSSSRET_DESTROY_ME) + goto late_bail; + lwsl_notice("%s: created server %s\n", __func__, h->policy->streamtype);