mirror of
https://github.com/warmcat/libwebsockets.git
synced 2025-03-09 00:00:04 +01:00
ss: enforce only valid state transitions
The various stream transitions for direct ss, SSPC, smd, and different protocols are all handled in different code, let's stop hoping for the best and add a state transition validation function that is used everywhere we pass a state change to a user callback, and knows what is valid for the user state() callback to see next, given the last state it was shown. Let's assert if lws manages to violate that so we can find where the problem is and provide a stricter guarantee about what user state handler will see, no matter if ss or sspc or other cases. To facilitate that, move the states to start from 1, where 0 indicates the state unset.
This commit is contained in:
parent
47905401fa
commit
aa45de9e2a
9 changed files with 213 additions and 21 deletions
|
@ -187,9 +187,13 @@ typedef uint32_t lws_ss_tx_ordinal_t;
|
|||
|
||||
/*
|
||||
* connection state events
|
||||
*
|
||||
* If you add states, take care about the state names and state transition
|
||||
* validity enforcement tables too
|
||||
*/
|
||||
typedef enum {
|
||||
LWSSSCS_CREATING,
|
||||
/* zero means unset */
|
||||
LWSSSCS_CREATING = 1,
|
||||
LWSSSCS_DISCONNECTED,
|
||||
LWSSSCS_UNREACHABLE, /* oridinal arg = 1 = caused by dns
|
||||
* server reachability failure */
|
||||
|
|
|
@ -301,8 +301,11 @@ lws_adopt_ss_server_accept(struct lws *new_wsi)
|
|||
goto fail;
|
||||
if (lws_ss_event_helper(h, LWSSSCS_CONNECTING))
|
||||
goto fail;
|
||||
if (lws_ss_event_helper(h, LWSSSCS_CONNECTED))
|
||||
goto fail;
|
||||
|
||||
/* defer CONNECTED until we see if he is upgrading */
|
||||
|
||||
// if (lws_ss_event_helper(h, LWSSSCS_CONNECTED))
|
||||
// goto fail;
|
||||
|
||||
// lwsl_notice("%s: accepted ss complete, pcol %s\n", __func__,
|
||||
// new_wsi->a.protocol->name);
|
||||
|
|
|
@ -334,7 +334,14 @@ lws_client_connect_via_info(const struct lws_client_connect_info *i)
|
|||
__lws_lc_tag(&i->context->lcg[
|
||||
#if defined(LWS_WITH_SECURE_STREAMS_PROXY_API)
|
||||
i->ssl_connection & LCCSCF_SECSTREAM_PROXY_LINK ? LWSLCG_WSI_SSP_CLIENT :
|
||||
(i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD : LWSLCG_WSI_CLIENT)],
|
||||
#if defined(LWS_WITH_SERVER)
|
||||
(i->ssl_connection & LCCSCF_SECSTREAM_PROXY_ONWARD ? LWSLCG_WSI_SSP_ONWARD :
|
||||
#endif
|
||||
LWSLCG_WSI_CLIENT
|
||||
#if defined(LWS_WITH_SERVER)
|
||||
)
|
||||
#endif
|
||||
],
|
||||
#else
|
||||
LWSLCG_WSI_CLIENT],
|
||||
#endif
|
||||
|
|
|
@ -697,9 +697,10 @@ lws_callback_http_dummy(struct lws *wsi, enum lws_callback_reasons reason,
|
|||
if (wsi->mux_substream && !wsi->cgi_stdout_zero_length)
|
||||
lws_write(wsi, (unsigned char *)buf + LWS_PRE, 0,
|
||||
LWS_WRITE_HTTP_FINAL);
|
||||
|
||||
#if defined(LWS_WITH_SERVER)
|
||||
if (lws_http_transaction_completed(wsi))
|
||||
return -1;
|
||||
#endif
|
||||
return 0;
|
||||
|
||||
case LWS_CALLBACK_CGI_STDIN_DATA: /* POST body for stdin */
|
||||
|
|
|
@ -156,6 +156,7 @@ typedef struct lws_ss_handle {
|
|||
uint8_t tsi; /**< service thread idx, usually 0 */
|
||||
uint8_t subseq; /**< emulate SOM tracking */
|
||||
uint8_t txn_ok; /**< 1 = transaction was OK */
|
||||
uint8_t prev_ss_state;
|
||||
|
||||
uint8_t txn_resp_set:1; /**< user code set one */
|
||||
uint8_t txn_resp_pending:1; /**< we have yet to send */
|
||||
|
@ -300,6 +301,8 @@ typedef struct lws_sspc_handle {
|
|||
uint8_t rideshare_ofs[4];
|
||||
uint8_t rsidx;
|
||||
|
||||
uint8_t prev_ss_state;
|
||||
|
||||
uint8_t conn_req_state:2;
|
||||
uint8_t destroying:1;
|
||||
uint8_t non_wsi:1;
|
||||
|
@ -457,6 +460,13 @@ struct lws_vhost *
|
|||
lws_ss_policy_ref_trust_store(struct lws_context *context,
|
||||
const lws_ss_policy_t *pol, char doref);
|
||||
|
||||
lws_ss_state_return_t
|
||||
lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs,
|
||||
lws_ss_tx_ordinal_t flags);
|
||||
|
||||
int
|
||||
lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs);
|
||||
|
||||
void
|
||||
lws_proxy_clean_conn_ss(struct lws *wsi);
|
||||
|
||||
|
|
|
@ -514,9 +514,11 @@ secstream_h1(struct lws *wsi, enum lws_callback_reasons reason, void *user,
|
|||
h->seqstate = SSSEQ_CONNECTED;
|
||||
lws_sul_cancel(&h->sul);
|
||||
|
||||
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
|
||||
if (r != LWSSSSRET_OK)
|
||||
return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h);
|
||||
if (h->prev_ss_state != LWSSSCS_CONNECTED) {
|
||||
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
|
||||
if (r != LWSSSSRET_OK)
|
||||
return _lws_ss_handle_state_ret_CAN_DESTROY_HANDLE(r, wsi, &h);
|
||||
}
|
||||
|
||||
/*
|
||||
* Since it's an http transaction we initiated... this is
|
||||
|
|
|
@ -19,6 +19,22 @@
|
|||
*/
|
||||
#include <private-lib-core.h>
|
||||
|
||||
lws_ss_state_return_t
|
||||
lws_sspc_event_helper(lws_sspc_handle_t *h, lws_ss_constate_t cs,
|
||||
lws_ss_tx_ordinal_t flags)
|
||||
{
|
||||
if (!h)
|
||||
return LWSSSSRET_OK;
|
||||
|
||||
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
|
||||
return LWSSSSRET_DESTROY_ME;
|
||||
|
||||
if (!h->ssi.state)
|
||||
return LWSSSSRET_OK;
|
||||
|
||||
return h->ssi.state((void *)((uint8_t *)&h[1]), NULL, cs, flags);
|
||||
}
|
||||
|
||||
static void
|
||||
lws_sspc_sul_retry_cb(lws_sorted_usec_list_t *sul)
|
||||
{
|
||||
|
@ -533,7 +549,6 @@ void
|
|||
lws_sspc_destroy(lws_sspc_handle_t **ph)
|
||||
{
|
||||
lws_sspc_handle_t *h;
|
||||
void *m;
|
||||
|
||||
lwsl_debug("%s\n", __func__);
|
||||
|
||||
|
@ -541,7 +556,6 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
|
|||
return;
|
||||
|
||||
h = *ph;
|
||||
m = (void *)((uint8_t *)&h[1]);
|
||||
|
||||
if (h->destroying)
|
||||
return;
|
||||
|
@ -549,7 +563,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
|
|||
h->destroying = 1;
|
||||
|
||||
if (h->ss_dangling_connected && h->ssi.state) {
|
||||
h->ssi.state(m, NULL, LWSSSCS_DISCONNECTED, 0);
|
||||
lws_sspc_event_helper(h, LWSSSCS_DISCONNECTED, 0);
|
||||
h->ss_dangling_connected = 0;
|
||||
}
|
||||
|
||||
|
@ -578,8 +592,7 @@ lws_sspc_destroy(lws_sspc_handle_t **ph)
|
|||
|
||||
lws_sspc_rxmetadata_destroy(h);
|
||||
|
||||
if (h->ssi.state)
|
||||
h->ssi.state(m, NULL, LWSSSCS_DESTROYING, 0);
|
||||
lws_sspc_event_helper(h, LWSSSCS_DESTROYING, 0);
|
||||
*ph = NULL;
|
||||
|
||||
__lws_lc_untag(&h->lc);
|
||||
|
|
|
@ -1240,6 +1240,11 @@ payload_ff:
|
|||
|
||||
h->creating_cb_done = 1;
|
||||
|
||||
if (lws_ss_check_next_state(&h->prev_ss_state, LWSSSCS_CREATING))
|
||||
return LWSSSSRET_DESTROY_ME;
|
||||
|
||||
h->prev_ss_state = (uint8_t)LWSSSCS_CREATING;
|
||||
|
||||
if (ssi->state) {
|
||||
n = ssi->state(client_pss_to_userdata(pss),
|
||||
NULL, LWSSSCS_CREATING, 0);
|
||||
|
@ -1391,8 +1396,14 @@ payload_ff:
|
|||
if (cs == LWSSSCS_DISCONNECTED)
|
||||
h->ss_dangling_connected = 0;
|
||||
|
||||
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
|
||||
return LWSSSSRET_DESTROY_ME;
|
||||
|
||||
if (cs < LWSSSCS_USER_BASE)
|
||||
h->prev_ss_state = (uint8_t)cs;
|
||||
|
||||
n = ssi->state(client_pss_to_userdata(pss),
|
||||
NULL, (lws_ss_constate_t)par->ctr, par->flags);
|
||||
NULL, cs, par->flags);
|
||||
switch (n) {
|
||||
case LWSSSSRET_OK:
|
||||
break;
|
||||
|
|
|
@ -50,6 +50,7 @@ static const struct ss_pcols *ss_pcols[] = {
|
|||
};
|
||||
|
||||
static const char *state_names[] = {
|
||||
"(unset)",
|
||||
"LWSSSCS_CREATING",
|
||||
"LWSSSCS_DISCONNECTED",
|
||||
"LWSSSCS_UNREACHABLE",
|
||||
|
@ -68,6 +69,144 @@ static const char *state_names[] = {
|
|||
"LWSSSCS_SERVER_UPGRADE",
|
||||
};
|
||||
|
||||
/*
|
||||
* For each "current state", set bit offsets for valid "next states".
|
||||
*
|
||||
* Since there are complicated ways to arrive at state transitions like proxying
|
||||
* and asynchronous destruction etc, so we monitor the state transitions we are
|
||||
* giving the ss user code to ensure we never deliver illegal state transitions
|
||||
* (because we will assert if we have bugs that do it)
|
||||
*/
|
||||
|
||||
static const uint32_t ss_state_txn_validity[] = {
|
||||
|
||||
/* if we was last in this state... we can legally go to these states */
|
||||
|
||||
[0] = (1 << LWSSSCS_CREATING) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_CREATING] = (1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_POLL) |
|
||||
(1 << LWSSSCS_SERVER_UPGRADE) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_DISCONNECTED] = (1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_POLL) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_UNREACHABLE] = (1 << LWSSSCS_ALL_RETRIES_FAILED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_POLL) |
|
||||
(1 << LWSSSCS_CONNECTING) |
|
||||
/* win conn failure > retry > succ */
|
||||
(1 << LWSSSCS_CONNECTED) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_AUTH_FAILED] = (1 << LWSSSCS_ALL_RETRIES_FAILED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_CONNECTED] = (1 << LWSSSCS_SERVER_UPGRADE) |
|
||||
(1 << LWSSSCS_AUTH_FAILED) |
|
||||
(1 << LWSSSCS_QOS_ACK_REMOTE) |
|
||||
(1 << LWSSSCS_QOS_NACK_REMOTE) |
|
||||
(1 << LWSSSCS_QOS_ACK_LOCAL) |
|
||||
(1 << LWSSSCS_QOS_NACK_LOCAL) |
|
||||
(1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_CONNECTING] = (1 << LWSSSCS_UNREACHABLE) |
|
||||
(1 << LWSSSCS_AUTH_FAILED) |
|
||||
(1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_CONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_DESTROYING] = 0,
|
||||
|
||||
[LWSSSCS_POLL] = (1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_ALL_RETRIES_FAILED] = (1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_QOS_ACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_QOS_NACK_REMOTE] = (1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_QOS_ACK_LOCAL] = (1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_QOS_NACK_LOCAL] = (1 << LWSSSCS_DESTROYING) |
|
||||
(1 << LWSSSCS_TIMEOUT),
|
||||
|
||||
[LWSSSCS_TIMEOUT] = (1 << LWSSSCS_CONNECTING) |
|
||||
(1 << LWSSSCS_POLL) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_SERVER_TXN] = (1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
|
||||
[LWSSSCS_SERVER_UPGRADE] = (1 << LWSSSCS_SERVER_TXN) |
|
||||
(1 << LWSSSCS_TIMEOUT) |
|
||||
(1 << LWSSSCS_DISCONNECTED) |
|
||||
(1 << LWSSSCS_DESTROYING),
|
||||
};
|
||||
|
||||
int
|
||||
lws_ss_check_next_state(uint8_t *prevstate, lws_ss_constate_t cs)
|
||||
{
|
||||
if (cs >= LWSSSCS_USER_BASE)
|
||||
/*
|
||||
* we can't judge user states, leave the old state and
|
||||
* just wave them through
|
||||
*/
|
||||
return 0;
|
||||
|
||||
if (cs >= LWS_ARRAY_SIZE(ss_state_txn_validity)) {
|
||||
/* we don't recognize this state as usable */
|
||||
lwsl_err("%s: bad new state %u\n", __func__, cs);
|
||||
assert(0);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (*prevstate >= LWS_ARRAY_SIZE(ss_state_txn_validity)) {
|
||||
/* existing state is broken */
|
||||
lwsl_err("%s: bad existing state %u\n", __func__,
|
||||
(unsigned int)*prevstate);
|
||||
assert(0);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (ss_state_txn_validity[*prevstate] & (1u << cs)) {
|
||||
/* this is explicitly allowed, update old state to new */
|
||||
*prevstate = (uint8_t)cs;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
lwsl_err("%s: transition from %s -> %s is illegal\n", __func__,
|
||||
lws_ss_state_name((int)*prevstate),
|
||||
lws_ss_state_name((int)cs));
|
||||
|
||||
assert(0);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
const char *
|
||||
lws_ss_state_name(int state)
|
||||
{
|
||||
|
@ -88,6 +227,9 @@ lws_ss_event_helper(lws_ss_handle_t *h, lws_ss_constate_t cs)
|
|||
if (!h)
|
||||
return LWSSSSRET_OK;
|
||||
|
||||
if (lws_ss_check_next_state(&h->prev_ss_state, cs))
|
||||
return LWSSSSRET_DESTROY_ME;
|
||||
|
||||
if (cs == LWSSSCS_CONNECTED)
|
||||
h->ss_dangling_connected = 1;
|
||||
if (cs == LWSSSCS_DISCONNECTED)
|
||||
|
@ -707,7 +849,7 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
|
|||
*/
|
||||
if (!(ssi->flags & LWSSSINFLAGS_PROXIED) &&
|
||||
pol == &pol_smd) {
|
||||
lws_ss_state_return_t r;
|
||||
|
||||
/*
|
||||
* So he has asked to be wired up to SMD over a SS link.
|
||||
* Register him as an smd participant in his own right.
|
||||
|
@ -722,12 +864,6 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
|
|||
if (!h->u.smd.smd_peer)
|
||||
goto late_bail;
|
||||
lwsl_info("%s: registered SS SMD\n", __func__);
|
||||
r = lws_ss_event_helper(h, LWSSSCS_CONNECTING);
|
||||
if (r)
|
||||
return r;
|
||||
r = lws_ss_event_helper(h, LWSSSCS_CONNECTED);
|
||||
if (r)
|
||||
return r;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -798,6 +934,11 @@ lws_ss_create(struct lws_context *context, int tsi, const lws_ss_info_t *ssi,
|
|||
*/
|
||||
vho->ss_handle = h;
|
||||
|
||||
r = lws_ss_event_helper(h, LWSSSCS_CREATING);
|
||||
lwsl_info("%s: CREATING returned status %d\n", __func__, (int)r);
|
||||
if (r == LWSSSSRET_DESTROY_ME)
|
||||
goto late_bail;
|
||||
|
||||
lwsl_notice("%s: created server %s\n", __func__,
|
||||
h->policy->streamtype);
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue