From 5df3fabcd1bd063138a2b81d1dd6e2f1b6a9f0c2 Mon Sep 17 00:00:00 2001
From: Jitpanu Maneeratpongsuk <jitpanu.maneeratpongsuk@rwth-aachen.de>
Date: Thu, 19 Dec 2024 15:09:19 +0000
Subject: [PATCH] feat: Add basic tcp connection to socket node type

Signed-off-by: Jitpanu Maneeratpongsuk <jitpanu.maneeratpongsuk@rwth-aachen.de>
---
 include/villas/nodes/socket.hpp |  1 +
 include/villas/socket_addr.hpp  |  2 +-
 lib/nodes/socket.cpp            | 87 +++++++++++++++++++++++++++++++--
 lib/socket_addr.cpp             |  6 +++
 4 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/include/villas/nodes/socket.hpp b/include/villas/nodes/socket.hpp
index 8882a607d..671d508f4 100644
--- a/include/villas/nodes/socket.hpp
+++ b/include/villas/nodes/socket.hpp
@@ -22,6 +22,7 @@ class NodeCompat;
 
 struct Socket {
   int sd; // The socket descriptor
+  int clt_sd; // TCP client socket descriptor
   int verify_source; // Verify the source address of incoming packets against socket::remote.
 
   enum SocketLayer
diff --git a/include/villas/socket_addr.hpp b/include/villas/socket_addr.hpp
index d3947a171..34dbe9bc9 100644
--- a/include/villas/socket_addr.hpp
+++ b/include/villas/socket_addr.hpp
@@ -35,7 +35,7 @@ union sockaddr_union {
 namespace villas {
 namespace node {
 
-enum class SocketLayer { ETH, IP, UDP, UNIX };
+enum class SocketLayer { ETH, IP, UDP, UNIX, TCP_CLIENT, TCP_SERVER};
 
 /* Generate printable socket address depending on the address family
  *
diff --git a/lib/nodes/socket.cpp b/lib/nodes/socket.cpp
index e7e56ed08..784441b92 100644
--- a/lib/nodes/socket.cpp
+++ b/lib/nodes/socket.cpp
@@ -28,6 +28,9 @@
 #include <villas/kernel/nl.hpp>
 #endif // WITH_NETEM
 
+#define MAX_CONNECTION_RETRIES 5
+#define RETRIES_DELAY 2
+
 using namespace villas;
 using namespace villas::utils;
 using namespace villas::node;
@@ -97,6 +100,11 @@ char *villas::node::socket_print(NodeCompat *n) {
   case SocketLayer::UNIX:
     layer = "unix";
     break;
+
+  case SocketLayer::TCP_SERVER:
+  case SocketLayer::TCP_CLIENT:
+    layer = "tcp";
+    break;
   }
 
   char *local = socket_print_addr((struct sockaddr *)&s->in.saddr);
@@ -195,6 +203,11 @@ int villas::node::socket_start(NodeCompat *n) {
     s->sd = socket(s->in.saddr.sa.sa_family, SOCK_DGRAM, 0);
     break;
 
+  case SocketLayer::TCP_SERVER:
+  case SocketLayer::TCP_CLIENT:
+    s->sd = socket(s->in.saddr.sa.sa_family, SOCK_STREAM, 0);
+    break;
+
   default:
     throw RuntimeError("Invalid socket type!");
   }
@@ -233,9 +246,43 @@ int villas::node::socket_start(NodeCompat *n) {
     addrlen = sizeof(s->in.saddr);
   }
 
-  ret = bind(s->sd, (struct sockaddr *)&s->in.saddr, addrlen);
-  if (ret < 0)
-    throw SystemError("Failed to bind socket");
+  if (s->layer == SocketLayer::TCP_CLIENT) {
+    //Attempt to connect to TCP server
+    int retries = 0;
+    while (retries < MAX_CONNECTION_RETRIES) {
+      n->logger->info("Attempting({}) to connect to server..", retries + 1);
+      ret = connect(s->sd, (struct sockaddr *)&s->out.saddr, addrlen);
+      if (ret == 0) {
+        break;
+      } else {
+        retries++;
+        if (retries < MAX_CONNECTION_RETRIES) {
+          sleep(RETRIES_DELAY);
+        }
+      }
+    }
+
+  } else {
+    ret = bind(s->sd, (struct sockaddr *)&s->in.saddr, addrlen);
+  }
+
+  if (ret < 0) {
+    if (s->layer == SocketLayer::TCP_CLIENT) {
+      throw SystemError("Failed to connect to TCP server");
+    } else {
+      throw SystemError("Failed to bind socket");
+    }
+  }
+
+  //TCP Server listen for client connection
+  if (s->layer == SocketLayer::TCP_SERVER) {
+    listen(s->sd, 5);
+    //Accept client connection and get client socket descriptor
+    s->clt_sd = accept(s->sd, nullptr, nullptr);
+    if (s->clt_sd < 0) {
+      throw SystemError("Failed to accept connection");
+    }
+  }
 
   if (s->multicast.enabled) {
     ret = setsockopt(s->sd, IPPROTO_IP, IP_MULTICAST_LOOP, &s->multicast.loop,
@@ -258,6 +305,8 @@ int villas::node::socket_start(NodeCompat *n) {
   int prio;
   switch (s->layer) {
   case SocketLayer::UDP:
+  case SocketLayer::TCP_SERVER:
+  case SocketLayer::TCP_CLIENT:
   case SocketLayer::IP:
     prio = IPTOS_LOWDELAY;
     if (setsockopt(s->sd, IPPROTO_IP, IP_TOS, &prio, sizeof(prio)))
@@ -316,7 +365,12 @@ int villas::node::socket_stop(NodeCompat *n) {
   }
 
   if (s->sd >= 0) {
+    //Close client socket descriptor
+    if (s->layer == SocketLayer::TCP_SERVER)
+      close(s->clt_sd);
+
     ret = close(s->sd);
+
     if (ret)
       return ret;
   }
@@ -340,7 +394,17 @@ int villas::node::socket_read(NodeCompat *n, struct Sample *const smps[],
   socklen_t srclen = sizeof(src);
 
   // Receive next sample
-  bytes = recvfrom(s->sd, s->in.buf, s->in.buflen, 0, &src.sa, &srclen);
+
+  if (s->layer == SocketLayer::TCP_CLIENT) {
+    //Receive data from server
+    bytes = recv(s->sd, s->in.buf, s->in.buflen, 0);
+  } else if (s->layer == SocketLayer::TCP_SERVER) {
+    //Receive data from client
+    bytes = recv(s->clt_sd, s->in.buf, s->in.buflen, 0);
+  } else {
+    bytes = recvfrom(s->sd, s->in.buf, s->in.buflen, 0, &src.sa, &srclen);
+  }
+
   if (bytes < 0) {
     if (errno == EINTR)
       return -1;
@@ -445,8 +509,17 @@ retry:
   }
 
 retry2:
-  bytes = sendto(s->sd, s->out.buf, wbytes, 0, (struct sockaddr *)&s->out.saddr,
+  if (s->layer == SocketLayer::TCP_CLIENT) {
+    //Send data to TCP server
+    bytes = send(s->sd, s->out.buf, wbytes, 0);
+  } else if (s->layer == SocketLayer::TCP_SERVER) {
+    //Send data to TCP client
+    bytes = send(s->clt_sd, s->out.buf, wbytes, 0);
+  } else {
+    bytes = sendto(s->sd, s->out.buf, wbytes, 0, (struct sockaddr *)&s->out.saddr,
                  addrlen);
+  }
+
   if (bytes < 0) {
     if ((errno == EPERM) || (errno == ENOENT && s->layer == SocketLayer::UNIX))
       n->logger->warn("Failed sendto(): {}", strerror(errno));
@@ -505,6 +578,10 @@ int villas::node::socket_parse(NodeCompat *n, json_t *json) {
       s->layer = SocketLayer::UDP;
     else if (!strcmp(layer, "unix") || !strcmp(layer, "local"))
       s->layer = SocketLayer::UNIX;
+    else if (!strcmp(layer, "tcp_client"))
+      s->layer = SocketLayer::TCP_CLIENT;
+    else if (!strcmp(layer, "tcp_server"))
+      s->layer = SocketLayer::TCP_SERVER;
     else
       throw SystemError("Invalid layer '{}'", layer);
   }
diff --git a/lib/socket_addr.cpp b/lib/socket_addr.cpp
index a8e4e5f4c..2da3bbc73 100644
--- a/lib/socket_addr.cpp
+++ b/lib/socket_addr.cpp
@@ -156,6 +156,12 @@ int villas::node::socket_parse_address(const char *addr, struct sockaddr *saddr,
       hint.ai_protocol = IPPROTO_UDP;
       break;
 
+    case SocketLayer::TCP_CLIENT:
+    case SocketLayer::TCP_SERVER:
+      hint.ai_socktype = SOCK_STREAM;
+      hint.ai_protocol = IPPROTO_TCP;
+      break;
+
     default:
       throw RuntimeError("Invalid address type");
     }