diff --git a/server/include/socket.h b/server/include/socket.h index baa5e3e0a..0b8b63bf3 100644 --- a/server/include/socket.h +++ b/server/include/socket.h @@ -16,6 +16,8 @@ struct socket { /** The socket descriptor */ int sd; + /** The socket descriptor for an established TCP connection */ + int sd2; /** Socket mark for netem, routing and filtering */ int mark; diff --git a/server/src/node.c b/server/src/node.c index e3aa6789b..afd02ea41 100644 --- a/server/src/node.c +++ b/server/src/node.c @@ -77,20 +77,20 @@ int node_start(struct node *n) int node_start_defer(struct node *n) { - switch (node_type(n)) { - case TCPD: - info("Wait for incoming TCP connection from node '%s'...", n->name); - listen(n->socket->sd, 1); - n->socket->sd = accept(n->socket->sd, NULL, NULL); - break; - - case TCP: - info("Connect with TCP to remote node '%s'", n->name); - connect(n->socket->sd, (struct sockaddr *) &n->socket->remote, sizeof(n->socket->remote)); - break; + int ret; - default: - break; + if (node_type(n) == TCPD) { + info("Wait for incoming TCP connection from node '%s'...", n->name); + + ret = listen(n->socket->sd2, 1); + if (ret < 0) + serror("Failed to listen on socket for node '%s'", n->name); + + ret = accept(n->socket->sd2, NULL, NULL); + if (ret < 0) + serror("Failed to accept on socket for node '%s'", n->name); + + n->socket->sd = ret; } return 0; diff --git a/server/src/socket.c b/server/src/socket.c index 770d91ca0..eacbe572d 100644 --- a/server/src/socket.c +++ b/server/src/socket.c @@ -43,15 +43,19 @@ int socket_print(struct node *n, char *buf, int len) int socket_open(struct node *n) { struct socket *s = n->socket; - int af = s->local.ss_family; - + struct sockaddr_in *sin = (struct sockaddr_in *) &s->local; + struct sockaddr_ll *sll = (struct sockaddr_ll *) &s->local; + int ret; + + s->sd = s->sd2 = -1; + /* Create socket */ switch (node_type(n)) { case TCPD: - case TCP: s->sd = socket(af, SOCK_STREAM, 0); break; - case UDP: s->sd = socket(af, SOCK_DGRAM, 0); break; - case IP: s->sd = socket(af, SOCK_RAW, IPPROTO_S2SS); break; - case IEEE_802_3:s->sd = socket(af, SOCK_DGRAM, ETH_P_S2SS); break; + case TCP: s->sd = socket(sin->sin_family, SOCK_STREAM, IPPROTO_TCP); break; + case UDP: s->sd = socket(sin->sin_family, SOCK_DGRAM, IPPROTO_UDP); break; + case IP: s->sd = socket(sin->sin_family, SOCK_RAW, ntohs(sin->sin_port)); break; + case IEEE_802_3:s->sd = socket(sin->sin_family, SOCK_DGRAM, sll->sll_protocol); break; default: error("Invalid socket type!"); } @@ -60,8 +64,20 @@ int socket_open(struct node *n) serror("Failed to create socket"); /* Bind socket for receiving */ - if (bind(s->sd, (struct sockaddr *) &s->local, sizeof(s->local))) + ret = bind(s->sd, (struct sockaddr *) &s->local, sizeof(s->local)); + if (ret < 0) serror("Failed to bind socket"); + + /* Connect socket for sending */ + if (node_type(n) == TCPD) { + /* Listening TCP sockets will be connected later by calling accept() */ + s->sd2 = s->sd; + } + else if (node_type(n) != IEEE_802_3) { + ret = connect(s->sd, (struct sockaddr *) &s->remote, sizeof(s->remote)); + if (ret < 0) + serror("Failed to connect socket"); + } /* Determine outgoing interface */ int index = if_getegress((struct sockaddr *) &s->remote); @@ -103,19 +119,28 @@ int socket_open(struct node *n) int socket_close(struct node *n) { - return close(n->socket->sd); + struct socket *s = n->socket; + + if (s->sd >= 0) { + debug(5, "closing sd = %u", s->sd); + close(s->sd); + } + + if (s->sd2 >= 0) { + debug(5, "closing sd2 = %u", s->sd2); + close(s->sd2); + } + + return 0; } int socket_read(struct node* n, struct msg *m) { - /** @todo Fix this for multiple paths calling msg_recv. */ - /* Receive message from socket */ - if (recv(n->socket->sd, m, sizeof(struct msg), 0) < 0) { - if (errno == EINTR) - return -EINTR; - - } + int ret = recv(n->socket->sd, m, sizeof(struct msg), 0); + if (ret == 0) + error("Remote node '%s' closed the connection", n->name); + else if (ret < 0) serror("Failed recv"); /* Convert headers to host byte order */ @@ -133,12 +158,18 @@ int socket_read(struct node* n, struct msg *m) int socket_write(struct node* n, struct msg *m) { + struct socket *s = n->socket; + int ret; + /* Convert headers to network byte order */ m->sequence = htons(m->sequence); - if (sendto(n->socket->sd, m, MSG_LEN(m->length), 0, - (struct sockaddr *) &n->socket->remote, - sizeof(struct sockaddr_in)) < 0) + if (node_type(n) == IEEE_802_3) + ret = sendto(s->sd, m, MSG_LEN(m->length), 0, (struct sockaddr *) &s->remote, sizeof(s->remote)); + else + ret = send(s->sd, m, MSG_LEN(m->length), 0); + + if (ret < 0) serror("Failed sendto"); debug(10, "Message sent to node '%s': version=%u, type=%u, endian=%u, length=%u, sequence=%u", @@ -160,7 +191,7 @@ int socket_print_addr(char *buf, int len, struct sockaddr *sa) struct sockaddr_ll *sll = (struct sockaddr_ll *) sa; char ifname[IF_NAMESIZE]; - return snprintf(buf, len, "%s%%%s:%hu", + return snprintf(buf, len, "%s%%%s:%#hx", ether_ntoa((struct ether_addr *) &sll->sll_addr), if_indextoname(sll->sll_ifindex, ifname), ntohs(sll->sll_protocol)); @@ -185,7 +216,8 @@ int socket_parse_addr(const char *addr, struct sockaddr *sa, enum node_type type /* Split string */ char *node = strtok(copy, "%"); - char *ifname = strtok(NULL, "\0"); + char *ifname = strtok(NULL, ":"); + char *proto = strtok(NULL, "\0"); /* Parse link layer (MAC) address */ struct ether_addr *mac = ether_aton(node); @@ -194,15 +226,14 @@ int socket_parse_addr(const char *addr, struct sockaddr *sa, enum node_type type memcpy(&sll->sll_addr, &mac->ether_addr_octet, 6); - sll->sll_protocol = ETH_P_S2SS; + sll->sll_protocol = htons((proto) ? strtol(proto, NULL, 0) : ETH_P_S2SS); sll->sll_halen = 6; sll->sll_family = AF_PACKET; sll->sll_ifindex = if_nametoindex(ifname); ret = 0; } - else { - //struct sockaddr_in *sin = (struct sockaddr_in *) sa; + else { /* Format: "192.168.0.10:12001" */ struct addrinfo hint = { .ai_flags = flags, .ai_family = AF_UNSPEC @@ -220,8 +251,9 @@ int socket_parse_addr(const char *addr, struct sockaddr *sa, enum node_type type switch (type) { case IP: - hint.ai_socktype = 0; - hint.ai_protocol = IPPROTO_S2SS; + hint.ai_socktype = SOCK_RAW; + hint.ai_protocol = (service) ? strtol(service, NULL, 0) : IPPROTO_S2SS; + hint.ai_flags |= AI_NUMERICSERV; break; case TCPD: @@ -235,16 +267,23 @@ int socket_parse_addr(const char *addr, struct sockaddr *sa, enum node_type type hint.ai_protocol = IPPROTO_UDP; break; - case INVALID: default: error("Invalid address type"); } /* Lookup address */ struct addrinfo *result; - ret = getaddrinfo(node, service, &hint, &result); + ret = getaddrinfo(node, (type == IP) ? NULL : service, &hint, &result); if (!ret) { - memcpy(sa, result->ai_addr, result->ai_addrlen); + + if (type == IP) { + /* We mis-use the sin_port field to store the IP protocol number on RAW sockets */ + struct sockaddr_in *sin = (struct sockaddr_in *) result->ai_addr; + sin->sin_port = htons(result->ai_protocol); + } + + memcpy(sa, result->ai_addr, result->ai_addrlen); + freeaddrinfo(result); } }