diff --git a/include/villas/nodes/socket.h b/include/villas/nodes/socket.h index 9714bf834..252a9ebbf 100644 --- a/include/villas/nodes/socket.h +++ b/include/villas/nodes/socket.h @@ -32,6 +32,7 @@ #pragma once #include +#include #include #include "node.h" @@ -53,27 +54,27 @@ union sockaddr_union { struct sockaddr_in sin; struct sockaddr_in6 sin6; struct sockaddr_ll sll; + struct sockaddr_un sun; }; struct socket { - int sd; /**> The socket descriptor */ - int mark; /**> Socket mark for netem, routing and filtering */ + int sd; /**< The socket descriptor */ + int mark; /**< Socket mark for netem, routing and filtering */ + int verify_source; /**< Verify the source address of incoming packets against socket::remote. */ enum { SOCKET_ENDIAN_LITTLE, SOCKET_ENDIAN_BIG - } endian; /** Endianness of the data sent/received by the node */ + } endian; /**< Endianness of the data sent/received by the node */ - enum socket_layer layer; /**> The OSI / IP layer which should be used for this socket */ - enum socket_header header; /**> Payload header type */ + enum socket_layer layer; /**< The OSI / IP layer which should be used for this socket */ + enum socket_header header; /**< Payload header type */ - union sockaddr_union local; /**> Local address of the socket */ - union sockaddr_union remote; /**> Remote address of the socket */ + union sockaddr_union local; /**< Local address of the socket */ + union sockaddr_union remote; /**< Remote address of the socket */ - struct rtnl_qdisc *tc_qdisc; /**> libnl3: Network emulator queuing discipline */ - struct rtnl_cls *tc_classifier; /**> libnl3: Firewall mark classifier */ - - struct socket *next; /* Linked list _per_interface_ */ + struct rtnl_qdisc *tc_qdisc; /**< libnl3: Network emulator queuing discipline */ + struct rtnl_cls *tc_classifier; /**< libnl3: Firewall mark classifier */ }; @@ -127,4 +128,6 @@ char * socket_print_addr(struct sockaddr *saddr); */ int socket_parse_addr(const char *str, struct sockaddr *sa, enum socket_layer layer, int flags); -/** @} */ \ No newline at end of file +int socket_compare_addr(struct sockaddr *x, struct sockaddr *y); + +/** @} */ diff --git a/lib/nodes/socket.c b/lib/nodes/socket.c index 6025fff45..cd1fde8fc 100644 --- a/lib/nodes/socket.c +++ b/lib/nodes/socket.c @@ -330,8 +330,11 @@ static int socket_read_villas(struct node *n, struct sample *smps[], unsigned cn char data[MSG_MAX_PACKET_LEN]; ssize_t bytes; + struct sockaddr_storage src; + socklen_t srclen = sizeof(src); + /* Receive message from socket */ - bytes = recv(s->sd, data, sizeof(data), 0); + bytes = recvfrom(s->sd, data, sizeof(data), 0, (struct sockaddr *) &src, &srclen); if (bytes == 0) error("Remote node %s closed the connection", node_name(n)); else if (bytes < 0) @@ -340,11 +343,19 @@ static int socket_read_villas(struct node *n, struct sample *smps[], unsigned cn warn("Received invalid packet for node %s", node_name(n)); return 0; } - + + if (s->verify_source && socket_compare_addr((struct sockaddr *) &src, (struct sockaddr *) &s->remote) != 0) { + char *buf = socket_print_addr((struct sockaddr *) &src); + + warn("Received packet from unauthorized source: %s", buf); + + free(buf); + } + ret = msg_buffer_to_samples(smps, cnt, data, bytes); if (ret < 0) warn("Received invalid packet from node: %s", node_name(n)); - + return ret; } @@ -490,6 +501,9 @@ int socket_parse(struct node *n, config_setting_t *cfg) if (!config_setting_lookup_string(cfg, "local", &local)) cerror(cfg, "Missing local address for node %s", node_name(n)); + if (!config_setting_lookup_bool(cfg, "verify_source", &s->verify_source)) + s->verify_source = 0; + ret = socket_parse_addr(local, (struct sockaddr *) &s->local, s->layer, AI_PASSIVE); if (ret) { cerror(cfg, "Failed to resolve local address '%s' of node %s: %s", @@ -649,6 +663,47 @@ int socket_parse_addr(const char *addr, struct sockaddr *saddr, enum socket_laye return ret; } +int socket_compare_addr(struct sockaddr *x, struct sockaddr *y) +{ +#define CMP(a, b) if (a != b) return a < b ? -1 : 1 + + union sockaddr_union *xu = (void *) x, *yu = (void *) y; + + CMP(x->sa_family, y->sa_family); + + switch (x->sa_family) { + case AF_UNIX: + return strcmp(xu->sun.sun_path, yu->sun.sun_path); + + case AF_INET: + CMP(ntohl(xu->sin.sin_addr.s_addr), ntohl(yu->sin.sin_addr.s_addr)); + CMP(ntohs(xu->sin.sin_port), ntohs(yu->sin.sin_port)); + + return 0; + + case AF_INET6: + CMP(ntohs(xu->sin6.sin6_port), ntohs(yu->sin6.sin6_port)); +// CMP(xu->sin6.sin6_flowinfo, yu->sin6.sin6_flowinfo); +// CMP(xu->sin6.sin6_scope_id, yu->sin6.sin6_scope_id); + + return memcmp(xu->sin6.sin6_addr.s6_addr, yu->sin6.sin6_addr.s6_addr, sizeof(xu->sin6.sin6_addr.s6_addr)); + + case AF_PACKET: + CMP(xu->sll.sll_protocol, yu->sll.sll_protocol); + CMP(xu->sll.sll_ifindex, yu->sll.sll_ifindex); +// CMP(xu->sll.sll_pkttype, yu->sll.sll_pkttype); +// CMP(xu->sll.sll_hatype, yu->sll.sll_hatype); +// CMP(xu->sll.sll_halen, yu->sll.sll_halen); + + return memcmp(xu->sll.sll_addr, yu->sll.sll_addr, sizeof(xu->sll.sll_addr)); + + default: + return -1; + } + +#undef CMP +} + static struct plugin p = { .name = "socket", .description = "BSD network sockets",