1
0
Fork 0
mirror of https://github.com/warmcat/libwebsockets.git synced 2025-03-09 00:00:04 +01:00

ss: enforce only valid state transitions

The various stream transitions for direct ss, SSPC, smd, and
different protocols are all handled in different code, let's
stop hoping for the best and add a state transition validation
function that is used everywhere we pass a state change to a
user callback, and knows what is valid for the user state()
callback to see next, given the last state it was shown.

Let's assert if lws manages to violate that so we can find
where the problem is and provide a stricter guarantee about
what user state handler will see, no matter if ss or sspc
or other cases.

To facilitate that, move the states to start from 1, where
0 indicates the state unset.
This commit is contained in:
Andy Green 2021-01-02 10:49:43 +00:00
parent 47905401fa
commit aa45de9e2a
9 changed files with 213 additions and 21 deletions

View file

@ -187,9 +187,13 @@ typedef uint32_t lws_ss_tx_ordinal_t;
/*
* connection state events
*
* If you add states, take care about the state names and state transition
* validity enforcement tables too
*/
typedef enum {
LWSSSCS_CREATING,
/* zero means unset */
LWSSSCS_CREATING = 1,
LWSSSCS_DISCONNECTED,
LWSSSCS_UNREACHABLE, /* oridinal arg = 1 = caused by dns
* server reachability failure */

View file

@ -301,8 +301,11 @@ lws_adopt_ss_server_accept(struct lws *new_wsi)
goto fail;
if (lws_ss_event_helper(h, LWSSSCS_CONNECTING))
goto fail;
if (lws_ss_event_helper(h, LWSSSCS_CONNECTED))
goto fail;
/* defer CONNECTED until we see if he is upgrading */
// if (lws_ss_event_helper(h, LWSSSCS_CONNECTED))
// goto fail;
// lwsl_notice("%s: accepted ss complete, pcol %s\n", __func__,
// new_wsi->a.protocol->name);

View file

@ -334,7 +334,14 @@ lws_client_connect_via_info(const struct lws_client_connect_info *i)
__lws_lc_tag(&i->context->lcg[
#if defined(LWS_WITH_SECURE_STREAMS_PROXY_API)
i->ssl_connection & LCCSCF_SECSTREAM_PROXY_LINK ? LWSLCG_WSI_SSP_CLIENT :
(i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : LWSLCG_WSI_CLIENT)],
#if defined(LWS_WITH_SERVER)
(i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD :
#endif
LWSLCG_WSI_CLIENT
#if defined(LWS_WITH_SERVER)
)
#endif
],
#else
LWSLCG_WSI_CLIENT],
#endif

View file

@ -697,9 +697,10 @@ lws_callback_http_dummy(struct lws *wsi, enum lws_callback_reasons reason,
if (wsi->mux_substream && !wsi->cgi_stdout_zero_length)
lws_write(wsi, (unsigned char *)buf + LWS_PRE, 0,
LWS_WRITE_HTTP_FINAL);
#if defined(LWS_WITH_SERVER)
if (lws_http_transaction_completed(wsi))
return -1;
#endif
return 0;
case LWS_CALLBACK_CGI_STDIN_DATA: /* POST body for stdin */

View file

@ -156,6 +156,7 @@ typedef struct lws_ss_handle {
uint8_t tsi; /**< service thread idx, usually 0 */
uint8_t subseq; /**< emulate SOM tracking */
uint8_t txn_ok; /**< 1 = transaction was OK */
uint8_t prev_ss_state;
uint8_t txn_resp_set:1; /**< user code set one */
uint8_t txn_resp_pending:1; /**< we have yet to send */
@ -300,6 +301,8 @@ typedef struct lws_sspc_handle {
uint8_t rideshare_ofs[4];
uint8_t rsidx;
uint8_t prev_ss_state;
uint8_t conn_req_state:2;
uint8_t destroying:1;
uint8_t non_wsi:1;
@ -457,6 +460,13 @@ struct lws_vhost *
lws_ss_policy_ref_trust_store(struct lws_context *context,
const lws_ss_policy_t *pol, char doref);
lws_ss_state_return_t
lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs,
lws_ss_tx_ordinal_t flags);
int
lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs);
void
lws_proxy_clean_conn_ss(struct lws *wsi);

View file

@ -514,9 +514,11 @@ secstream_h1(struct lws *wsi, enum lws_callback_reasons reason, void *user,
h->seqstate = SSSEQ_CONNECTED;
lws_sul_cancel(&h->sul);
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
if (r != LWSSSSRET_OK)
return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h);
if (h->prev_ss_state != LWSSSCS_CONNECTED) {
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
if (r != LWSSSSRET_OK)
return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h);
}
/*
* Since it's an http transaction we initiated... this is

View file

@ -19,6 +19,22 @@
*/
#include <private-lib-core.h>
lws_ss_state_return_t
lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs,
lws_ss_tx_ordinal_t flags)
{
if (!h)
return LWSSSSRET_OK;
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
return LWSSSSRET_DESTROY_ME;
if (!h->ssi.state)
return LWSSSSRET_OK;
return h->ssi.state((void *)((uint8_t *)&h[1]), NULL, cs, flags);
}
static void
lws_sspc_sul_retry_cb(lws_sorted_usec_list_t *sul)
{
@ -533,7 +549,6 @@ void
lws_sspc_destroy(lws_sspc_handle_t **ph)
{
lws_sspc_handle_t *h;
void *m;
lwsl_debug("%s\n", __func__);
@ -541,7 +556,6 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
return;
h = *ph;
m = (void *)((uint8_t *)&h[1]);
if (h->destroying)
return;
@ -549,7 +563,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
h->destroying = 1;
if (h->ss_dangling_connected && h->ssi.state) {
h->ssi.state(m, NULL, LWSSSCS_DISCONNECTED, 0);
lws_sspc_event_helper(h, LWSSSCS_DISCONNECTED, 0);
h->ss_dangling_connected = 0;
}
@ -578,8 +592,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
lws_sspc_rxmetadata_destroy(h);
if (h->ssi.state)
h->ssi.state(m, NULL, LWSSSCS_DESTROYING, 0);
lws_sspc_event_helper(h, LWSSSCS_DESTROYING, 0);
*ph = NULL;
__lws_lc_untag(&h->lc);

View file

@ -1240,6 +1240,11 @@ payload_ff:
h->creating_cb_done = 1;
if (lws_ss_check_next_state(&h->prev_ss_state, LWSSSCS_CREATING))
return LWSSSSRET_DESTROY_ME;
h->prev_ss_state = (uint8_t)LWSSSCS_CREATING;
if (ssi->state) {
n = ssi->state(client_pss_to_userdata(pss),
NULL, LWSSSCS_CREATING, 0);
@ -1391,8 +1396,14 @@ payload_ff:
if (cs == LWSSSCS_DISCONNECTED)
h->ss_dangling_connected = 0;
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
return LWSSSSRET_DESTROY_ME;
if (cs < LWSSSCS_USER_BASE)
h->prev_ss_state = (uint8_t)cs;
n = ssi->state(client_pss_to_userdata(pss),
NULL, (lws_ss_constate_t)par->ctr, par->flags);
NULL, cs, par->flags);
switch (n) {
case LWSSSSRET_OK:
break;

View file

@ -50,6 +50,7 @@ static const struct ss_pcols *ss_pcols[] = {
};
static const char *state_names[] = {
"(unset)",
"LWSSSCS_CREATING",
"LWSSSCS_DISCONNECTED",
"LWSSSCS_UNREACHABLE",
@ -68,6 +69,144 @@ static const char *state_names[] = {
"LWSSSCS_SERVER_UPGRADE",
};
/*
* For each "current state", set bit offsets for valid "next states".
*
* Since there are complicated ways to arrive at state transitions like proxying
* and asynchronous destruction etc, so we monitor the state transitions we are
* giving the ss user code to ensure we never deliver illegal state transitions
* (because we will assert if we have bugs that do it)
*/
static const uint32_t ss_state_txn_validity[] = {
/* if we was last in this state... we can legally go to these states */
[0] = (1 << LWSSSCS_CREATING) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_CREATING] = (1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_POLL) |
(1 << LWSSSCS_SERVER_UPGRADE) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_DISCONNECTED] = (1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_POLL) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_UNREACHABLE] = (1 << LWSSSCS_ALL_RETRIES_FAILED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_POLL) |
(1 << LWSSSCS_CONNECTING) |
/* win conn failure > retry > succ */
(1 << LWSSSCS_CONNECTED) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_AUTH_FAILED] = (1 << LWSSSCS_ALL_RETRIES_FAILED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_CONNECTED] = (1 << LWSSSCS_SERVER_UPGRADE) |
(1 << LWSSSCS_AUTH_FAILED) |
(1 << LWSSSCS_QOS_ACK_REMOTE) |
(1 << LWSSSCS_QOS_NACK_REMOTE) |
(1 << LWSSSCS_QOS_ACK_LOCAL) |
(1 << LWSSSCS_QOS_NACK_LOCAL) |
(1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_CONNECTING] = (1 << LWSSSCS_UNREACHABLE) |
(1 << LWSSSCS_AUTH_FAILED) |
(1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_CONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_DESTROYING] = 0,
[LWSSSCS_POLL] = (1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_ALL_RETRIES_FAILED] = (1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_QOS_ACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_QOS_NACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_QOS_ACK_LOCAL] = (1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_QOS_NACK_LOCAL] = (1 << LWSSSCS_DESTROYING) |
(1 << LWSSSCS_TIMEOUT),
[LWSSSCS_TIMEOUT] = (1 << LWSSSCS_CONNECTING) |
(1 << LWSSSCS_POLL) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_SERVER_TXN] = (1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DESTROYING),
[LWSSSCS_SERVER_UPGRADE] = (1 << LWSSSCS_SERVER_TXN) |
(1 << LWSSSCS_TIMEOUT) |
(1 << LWSSSCS_DISCONNECTED) |
(1 << LWSSSCS_DESTROYING),
};
int
lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs)
{
if (cs >= LWSSSCS_USER_BASE)
/*
* we can't judge user states, leave the old state and
* just wave them through
*/
return 0;
if (cs >= LWS_ARRAY_SIZE(ss_state_txn_validity)) {
/* we don't recognize this state as usable */
lwsl_err("%s: bad new state %u\n", __func__, cs);
assert(0);
return 1;
}
if (*prevstate >= LWS_ARRAY_SIZE(ss_state_txn_validity)) {
/* existing state is broken */
lwsl_err("%s: bad existing state %u\n", __func__,
(unsigned int)*prevstate);
assert(0);
return 1;
}
if (ss_state_txn_validity[*prevstate] & (1u << cs)) {
/* this is explicitly allowed, update old state to new */
*prevstate = (uint8_t)cs;
return 0;
}
lwsl_err("%s: transition from %s -> %s is illegal\n", __func__,
lws_ss_state_name((int)*prevstate),
lws_ss_state_name((int)cs));
assert(0);
return 1;
}
const char *
lws_ss_state_name(int state)
{
@ -88,6 +227,9 @@ lws_ss_event_helper(lws_ss_handle_t *h, lws_ss_constate_t cs)
if (!h)
return LWSSSSRET_OK;
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
return LWSSSSRET_DESTROY_ME;
if (cs == LWSSSCS_CONNECTED)
h->ss_dangling_connected = 1;
if (cs == LWSSSCS_DISCONNECTED)
@ -707,7 +849,7 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
*/
if (!(ssi->flags & LWSSSINFLAGS_PROXIED) &&
pol == &pol_smd) {
lws_ss_state_return_t r;
/*
* So he has asked to be wired up to SMD over a SS link.
* Register him as an smd participant in his own right.
@ -722,12 +864,6 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
if (!h->u.smd.smd_peer)
goto late_bail;
lwsl_info("%s: registered SS SMD\n", __func__);
r = lws_ss_event_helper(h, LWSSSCS_CONNECTING);
if (r)
return r;
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
if (r)
return r;
}
#endif
@ -798,6 +934,11 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
*/
vho->ss_handle = h;
r = lws_ss_event_helper(h, LWSSSCS_CREATING);
lwsl_info("%s: CREATING returned status %d\n", __func__, (int)r);
if (r == LWSSSSRET_DESTROY_ME)
goto late_bail;
lwsl_notice("%s: created server %s\n", __func__,
h->policy->streamtype);