diff --git a/lib/client/client-handshake.c b/lib/client/client-handshake.c index 5062b4dd..a6b98fd6 100644 --- a/lib/client/client-handshake.c +++ b/lib/client/client-handshake.c @@ -103,7 +103,7 @@ lws_client_connect_2(struct lws *wsi) * to whatever we decided to connect to */ - lwsl_notice("%s: %p: address %s\n", __func__, wsi, ads); + lwsl_info("%s: %p: address %s\n", __func__, wsi, ads); n = lws_getaddrinfo46(wsi, ads, &result); @@ -717,12 +717,20 @@ lws_client_connect_via_info(struct lws_client_connect_info *i) struct lws *wsi; int v = SPEC_LATEST_SUPPORTED; const struct lws_protocols *p; + const char *local = i->protocol; if (i->context->requested_kill) return NULL; if (!i->context->protocol_init_done) lws_protocol_init(i->context); + /* + * If we have .local_protocol_name, use it to select the + * local protocol handler to bind to. Otherwise use .protocol if + * http[s]. + */ + if (i->local_protocol_name) + local = i->local_protocol_name; wsi = lws_zalloc(sizeof(struct lws), "client wsi"); if (wsi == NULL) @@ -765,10 +773,19 @@ lws_client_connect_via_info(struct lws_client_connect_info *i) wsi->protocol = &wsi->vhost->protocols[0]; - /* for http[s] connection, allow protocol selection by name */ - - if (i->method && i->vhost && i->protocol) { - p = lws_vhost_name_to_protocol(i->vhost, i->protocol); + /* + * 1) for http[s] connection, allow protocol selection by name + * 2) for ws[s], if local_protocol_name given also use it for + * local protocol binding... this defeats the server + * protocol negotiation if so + * + * Otherwise leave at protocols[0]... the server will tell us + * which protocol we are associated with since we can give it a + * list. + */ + if ((i->method || i->local_protocol_name) && wsi->vhost && local) { + lwsl_info("binding to %s\n", local); + p = lws_vhost_name_to_protocol(wsi->vhost, local); if (p) wsi->protocol = p; } diff --git a/lib/client/client.c b/lib/client/client.c index 13fe80bd..29e89e2b 100644 --- a/lib/client/client.c +++ b/lib/client/client.c @@ -816,7 +816,7 @@ lws_client_interpret_server_handshake(struct lws *wsi) len = lws_hdr_total_length(wsi, WSI_TOKEN_PROTOCOL); if (!len) { - lwsl_info("lws_client_int_s_hs: WSI_TOKEN_PROTOCOL is null\n"); + lwsl_info("%s: WSI_TOKEN_PROTOCOL is null\n", __func__); /* * no protocol name to work from, * default to first protocol @@ -842,7 +842,7 @@ lws_client_interpret_server_handshake(struct lws *wsi) } if (!okay) { - lwsl_err("lws_client_int_s_hs: got bad protocol %s\n", p); + lwsl_info("%s: got bad protocol %s\n", __func__, p); cce = "HS: PROTOCOL malformed"; goto bail2; } @@ -851,21 +851,47 @@ lws_client_interpret_server_handshake(struct lws *wsi) * identify the selected protocol struct and set it */ n = 0; - wsi->protocol = NULL; - while (wsi->vhost->protocols[n].callback && !wsi->protocol) { - if (strcmp(p, wsi->vhost->protocols[n].name) == 0) { + /* keep client connection pre-bound protocol */ + if (!(wsi->mode & LWSCM_FLAG_IMPLIES_CALLBACK_CLOSED_CLIENT_HTTP)) + wsi->protocol = NULL; + + while (wsi->vhost->protocols[n].callback) { + if (!wsi->protocol && + strcmp(p, wsi->vhost->protocols[n].name) == 0) { wsi->protocol = &wsi->vhost->protocols[n]; break; } n++; } - if (wsi->protocol == NULL) { - lwsl_err("lws_client_int_s_hs: fail protocol %s\n", p); - cce = "HS: Cannot match protocol"; - goto bail2; + if (!wsi->vhost->protocols[n].callback) { /* no match */ + /* if server, that's already fatal */ + if (!(wsi->mode & LWSCM_FLAG_IMPLIES_CALLBACK_CLOSED_CLIENT_HTTP)) { + lwsl_info("%s: fail protocol %s\n", __func__, p); + cce = "HS: Cannot match protocol"; + goto bail2; + } + + /* for client, find the index of our pre-bound protocol */ + + n = 0; + while (wsi->vhost->protocols[n].callback) { + if (strcmp(wsi->protocol->name, + wsi->vhost->protocols[n].name) == 0) { + wsi->protocol = &wsi->vhost->protocols[n]; + break; + } + n++; + } + + if (!wsi->vhost->protocols[n].callback) { + lwsl_err("Failed to match protocol %s\n", wsi->protocol->name); + goto bail2; + } } + lwsl_debug("Selected protocol %s\n", wsi->protocol->name); + check_extensions: /* * stitch protocol choice into the vh protocol linked list diff --git a/lib/libwebsockets.h b/lib/libwebsockets.h index 264dd994..3d4a4e2d 100644 --- a/lib/libwebsockets.h +++ b/lib/libwebsockets.h @@ -3321,6 +3321,13 @@ struct lws_client_connect_info { const char *iface; /**< NULL to allow routing on any interface, or interface name or IP * to bind the socket to */ + const char *local_protocol_name; + /**< NULL: .protocol is used both to select the local protocol handler + * to bind to and as the list of remote ws protocols we could + * accept. + * non-NULL: this protocol name is used to bind the connection to + * the local protocol handler. .protocol is used for the + * list of remote ws protocols we could accept */ /* Add new things just above here ---^ * This is part of the ABI, don't needlessly break compatibility