/* Netlink related functions.
 *
 * VILLASnode uses libnl3 to talk to the Linux kernel to gather networking related information
 *
 * Author: Steffen Vogel <post@steffenvogel.de>
 * SPDX-FileCopyrightText: 2014-2023 Institute for Automation of Complex Power Systems, RWTH Aachen University
 * SPDX-License-Identifier: Apache-2.0
 */

#include <linux/if_packet.h>

#include <netlink/route/link.h>
#include <netlink/route/route.h>

#include <villas/exceptions.hpp>
#include <villas/kernel/nl.hpp>
#include <villas/utils.hpp>

// Singleton for global netlink socket
static struct nl_sock *sock = nullptr;

using namespace villas;
using namespace villas::kernel::nl;

struct nl_sock *villas::kernel::nl::init() {
  int ret;

  if (!sock) {
    // Create connection to netlink
    sock = nl_socket_alloc();
    if (!sock)
      throw MemoryAllocationError();

    ret = nl_connect(sock, NETLINK_ROUTE);
    if (ret)
      throw RuntimeError("Failed to connect to kernel: {}", nl_geterror(ret));

    // Fill some caches
    struct nl_cache *cache;
    ret = rtnl_link_alloc_cache(sock, AF_UNSPEC, &cache);
    if (ret)
      throw RuntimeError("Failed to get list of interfaces: {}",
                         nl_geterror(ret));

    nl_cache_mngt_provide(cache);
  }

  return sock;
}

void villas::kernel::nl::shutdown() {
  nl_close(sock);
  nl_socket_free(sock);

  sock = nullptr;
}

static int egress_cb(struct nl_msg *msg, void *arg) {
  struct rtnl_route **route = (struct rtnl_route **)arg;

  if (rtnl_route_parse(nlmsg_hdr(msg), route))
    return NL_SKIP;

  return NL_STOP;
}

int villas::kernel::nl::get_egress(struct nl_addr *addr) {
  int ret;
  struct nl_sock *sock = nl::init();
  struct nl_cb *cb;
  struct nl_msg *msg = nlmsg_alloc_simple(RTM_GETROUTE, 0);
  struct rtnl_route *route = nullptr;

  // Build message
  struct rtmsg rmsg = {
      .rtm_family = (unsigned char)nl_addr_get_family(addr),
      .rtm_dst_len = (unsigned char)nl_addr_get_prefixlen(addr),
  };

  ret = nlmsg_append(msg, &rmsg, sizeof(rmsg), NLMSG_ALIGNTO);
  if (ret)
    goto out;

  ret = nla_put_addr(msg, RTA_DST, addr);
  if (ret)
    goto out;

  // Send message
  ret = nl_send_auto(sock, msg);
  if (ret < 0)
    goto out;

  // Hook into receive chain
  cb = nl_cb_alloc(NL_CB_CUSTOM);
  nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, egress_cb, &route);

  // Receive message
  nl_recvmsgs_report(sock, cb);
  nl_wait_for_ack(sock);

  // Check result
  if (!route || rtnl_route_get_nnexthops(route) != 1) {
    ret = -1;
    goto out2;
  }

  ret = rtnl_route_nh_get_ifindex(rtnl_route_nexthop_n(route, 0));

  rtnl_route_put(route);

out2:
  nl_cb_put(cb);
out:
  nlmsg_free(msg);

  return ret;
}

struct rtnl_link *villas::kernel::nl::get_egress_link(struct sockaddr *sa) {
  int ifindex = -1;

  switch (sa->sa_family) {
  case AF_INET:
  case AF_INET6: {
    struct sockaddr_in *sin = (struct sockaddr_in *)sa;
    struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sa;

    struct nl_addr *addr =
        (sa->sa_family == AF_INET)
            ? nl_addr_build(sin->sin_family, &sin->sin_addr.s_addr,
                            sizeof(sin->sin_addr.s_addr))
            : nl_addr_build(sin6->sin6_family, sin6->sin6_addr.s6_addr,
                            sizeof(sin6->sin6_addr));

    ifindex = nl::get_egress(addr);
    nl_addr_put(addr);
    if (ifindex < 0)
      throw RuntimeError("Netlink error: {}", nl_geterror(ifindex));
    break;
  }

  case AF_PACKET: {
    struct sockaddr_ll *sll = (struct sockaddr_ll *)sa;

    ifindex = sll->sll_ifindex;
    break;
  }
  }

  struct nl_cache *cache = nl_cache_mngt_require("route/link");

  return rtnl_link_get(cache, ifindex);
}