/* Node type: ZeroMQ.
 *
 * Author: Steffen Vogel <post@steffenvogel.de>
 * SPDX-FileCopyrightText: 2014-2023 Institute for Automation of Complex Power Systems, RWTH Aachen University
 * SPDX-License-Identifier: Apache-2.0
 */

#include <cstring>
#include <zmq.h>

#if ZMQ_VERSION_MAJOR < 4 || (ZMQ_VERSION_MAJOR == 4 && ZMQ_VERSION_MINOR <= 1)
#include <zmq_utils.h>
#endif

#include <villas/exceptions.hpp>
#include <villas/node_compat.hpp>
#include <villas/nodes/zeromq.hpp>
#include <villas/queue.h>
#include <villas/super_node.hpp>
#include <villas/utils.hpp>

using namespace villas;
using namespace villas::node;
using namespace villas::utils;

static void *context;

/*  Read one event off the monitor socket; return value and address
 * by reference, if not null, and event number by value.
 *
 * @returnval  -1 In case of error. */
static int get_monitor_event(void *monitor, int *value, char **address) {
  // First frame in message contains event number and value
  zmq_msg_t msg;
  zmq_msg_init(&msg);
  if (zmq_msg_recv(&msg, monitor, 0) == -1)
    return -1; // Interruped, presumably

  assert(zmq_msg_more(&msg));

  uint8_t *data = (uint8_t *)zmq_msg_data(&msg);
  uint16_t event = *(uint16_t *)(data);
  if (value)
    *value = *(uint32_t *)(data + 2);

  // Second frame in message contains event address
  zmq_msg_init(&msg);
  if (zmq_msg_recv(&msg, monitor, 0) == -1)
    return -1; // Interruped, presumably

  assert(!zmq_msg_more(&msg));

  if (address) {
    uint8_t *data = (uint8_t *)zmq_msg_data(&msg);
    size_t size = zmq_msg_size(&msg);
    *address = (char *)malloc(size + 1);
    memcpy(*address, data, size);
    *address[size] = 0;
  }

  return event;
}

int villas::node::zeromq_reverse(NodeCompat *n) {
  auto *z = n->getData<struct zeromq>();

  if (list_length(&z->out.endpoints) != 1 || list_length(&z->in.endpoints) != 1)
    return -1;

  char *subscriber = (char *)list_first(&z->in.endpoints);
  char *publisher = (char *)list_first(&z->out.endpoints);

  list_set(&z->in.endpoints, 0, publisher);
  list_set(&z->out.endpoints, 0, subscriber);

  return 0;
}

int villas::node::zeromq_init(NodeCompat *n) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  z->out.bind = 1;
  z->in.bind = 0;

  z->curve.enabled = false;
  z->ipv6 = 0;

  z->in.pending = 0;
  z->out.pending = 0;

  ret = list_init(&z->in.endpoints);
  if (ret)
    return ret;

  ret = list_init(&z->out.endpoints);
  if (ret)
    return ret;

  z->formatter = nullptr;

  return 0;
}

static int zeromq_parse_endpoints(json_t *json_ep, struct List *epl) {
  json_t *json_val;
  size_t i;
  const char *ep;

  switch (json_typeof(json_ep)) {
  case JSON_ARRAY:
    json_array_foreach(json_ep, i, json_val) {
      ep = json_string_value(json_val);
      if (!ep)
        throw ConfigError(json_val, "node-config-node-publish",
                          "All 'publish' settings must be strings");

      list_push(epl, strdup(ep));
    }
    break;

  case JSON_STRING:
    ep = json_string_value(json_ep);
    list_push(epl, strdup(ep));
    break;

  default:
    return -1;
  }

  return 0;
}

int villas::node::zeromq_parse(NodeCompat *n, json_t *json) {
  auto *z = n->getData<struct zeromq>();

  int ret;
  const char *type = nullptr;
  const char *in_filter = nullptr;
  const char *out_filter = nullptr;

  json_error_t err;
  json_t *json_in_ep = nullptr;
  json_t *json_out_ep = nullptr;
  json_t *json_curve = nullptr;
  json_t *json_format = nullptr;

  ret = json_unpack_ex(json, &err, 0,
                       "{ s?: { s?: o, s?: s, s?: b }, s?: { s?: o, s?: s, s?: "
                       "b }, s?: o, s?: s, s?: b, s?: o }",
                       "in", "subscribe", &json_in_ep, "filter", &in_filter,
                       "bind", &z->in.bind, "out", "publish", &json_out_ep,
                       "filter", &out_filter, "bind", &z->out.bind, "curve",
                       &json_curve, "pattern", &type, "ipv6", &z->ipv6,
                       "format", &json_format);
  if (ret)
    throw ConfigError(json, err, "node-config-node-zeromq");

  z->in.filter = in_filter ? strdup(in_filter) : nullptr;
  z->out.filter = out_filter ? strdup(out_filter) : nullptr;

  // Format
  if (z->formatter)
    delete z->formatter;
  z->formatter = json_format ? FormatFactory::make(json_format)
                             : FormatFactory::make("villas.binary");
  if (!z->formatter)
    throw ConfigError(json_format, "node-config-node-zeromq-format",
                      "Invalid format configuration");

  if (json_out_ep) {
    ret = zeromq_parse_endpoints(json_out_ep, &z->out.endpoints);
    if (ret)
      throw ConfigError(json_out_ep, "node-config-node-zeromq-publish",
                        "Failed to parse list of publish endpoints");
  }

  if (json_in_ep) {
    ret = zeromq_parse_endpoints(json_in_ep, &z->in.endpoints);
    if (ret)
      throw ConfigError(json_out_ep, "node-config-node-zeromq-subscribe",
                        "Failed to parse list of subscribe endpoints");
  }

  if (json_curve) {
    const char *public_key, *secret_key;

    z->curve.enabled = true;

    ret = json_unpack_ex(json_curve, &err, 0, "{ s: s, s: s, s?: b }",
                         "public_key", &public_key, "secret_key", &secret_key,
                         "enabled", &z->curve.enabled);
    if (ret)
      throw ConfigError(json_curve, err, "node-config-node-zeromq-curve",
                        "Failed to parse setting 'curve'");

    if (strlen(secret_key) != 40)
      throw ConfigError(
          json_curve, err, "node-config-node-zeromq-curve",
          "Setting 'curve.secret_key' must be a Z85 encoded CurveZMQ key");

    if (strlen(public_key) != 40)
      throw ConfigError(
          json_curve, err, "node-config-node-zeromq-curve",
          "Setting 'curve.public_key' must be a Z85 encoded CurveZMQ key");

    memcpy(z->curve.server.public_key, public_key, 41);
    memcpy(z->curve.server.secret_key, secret_key, 41);
  }

  // TODO: We should fix this. Its mostly done.
  if (z->curve.enabled)
    throw ConfigError(json_curve, "node-config-zeromq-curve",
                      "CurveZMQ support is currently broken");

  if (type) {
    if (!strcmp(type, "pubsub"))
      z->pattern = zeromq::Pattern::PUBSUB;
#ifdef ZMQ_BUILD_DISH
    else if (!strcmp(type, "radiodish"))
      z->pattern = zeromq::Pattern::RADIODISH;
#endif
    else
      throw ConfigError(json, "node-config-node-zeromq-type",
                        "Invalid type for ZeroMQ node: {}", n->getNameShort());
  }

  return 0;
}

char *villas::node::zeromq_print(NodeCompat *n) {
  auto *z = n->getData<struct zeromq>();

  char *buf = nullptr;
  const char *pattern = nullptr;

  switch (z->pattern) {
  case zeromq::Pattern::PUBSUB:
    pattern = "pubsub";
    break;

#ifdef ZMQ_BUILD_DISH
  case zeromq::Pattern::RADIODISH:
    pattern = "radiodish";
    break;
#endif
  }

  strcatf(&buf,
          "pattern=%s, ipv6=%s, crypto=%s, in.bind=%s, out.bind=%s, "
          "in.subscribe=[ ",
          pattern, z->ipv6 ? "yes" : "no", z->curve.enabled ? "yes" : "no",
          z->in.bind ? "yes" : "no", z->out.bind ? "yes" : "no");

  for (size_t i = 0; i < list_length(&z->in.endpoints); i++) {
    char *ep = (char *)list_at(&z->in.endpoints, i);

    strcatf(&buf, "%s ", ep);
  }

  strcatf(&buf, "], out.publish=[ ");

  for (size_t i = 0; i < list_length(&z->out.endpoints); i++) {
    char *ep = (char *)list_at(&z->out.endpoints, i);

    strcatf(&buf, "%s ", ep);
  }

  strcatf(&buf, "]");

  if (z->in.filter)
    strcatf(&buf, ", in.filter=%s", z->in.filter);

  if (z->out.filter)
    strcatf(&buf, ", out.filter=%s", z->out.filter);

  return buf;
}

int villas::node::zeromq_check(NodeCompat *n) {
  auto *z = n->getData<struct zeromq>();

  if (list_length(&z->in.endpoints) == 0 && list_length(&z->out.endpoints) == 0)
    return -1;

  return 0;
}

int villas::node::zeromq_type_start(villas::node::SuperNode *sn) {
  context = zmq_ctx_new();

  return context == nullptr;
}

int villas::node::zeromq_type_stop() { return zmq_ctx_term(context); }

int villas::node::zeromq_start(NodeCompat *n) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  struct zeromq::Dir *dirs[] = {&z->out, &z->in};

  z->formatter->start(n->getInputSignals(false), ~(int)SampleFlags::HAS_OFFSET);

  switch (z->pattern) {
#ifdef ZMQ_BUILD_DISH
  case zeromq::Pattern::RADIODISH:
    z->in.socket = zmq_socket(context, ZMQ_DISH);
    z->out.socket = zmq_socket(context, ZMQ_RADIO);
    break;
#endif

  case zeromq::Pattern::PUBSUB:
    z->in.socket = zmq_socket(context, ZMQ_SUB);
    z->out.socket = zmq_socket(context, ZMQ_PUB);
    break;
  }

  if (!z->in.socket || !z->out.socket) {
    ret = -1;
    goto fail;
  }

  // Join group
  switch (z->pattern) {
#ifdef ZMQ_BUILD_DISH
  case zeromq::Pattern::RADIODISH:
    ret = zmq_join(z->in.socket, z->in.filter);
    break;
#endif

  case zeromq::Pattern::PUBSUB:
    ret = zmq_setsockopt(z->in.socket, ZMQ_SUBSCRIBE, z->in.filter,
                         z->in.filter ? strlen(z->in.filter) : 0);
    break;

  default:
    ret = -1;
  }

  if (ret < 0)
    goto fail;

  if (z->curve.enabled) {
    // Publisher has server role
    ret = zmq_setsockopt(z->out.socket, ZMQ_CURVE_SECRETKEY,
                         z->curve.server.secret_key, 41);
    if (ret)
      goto fail;

    ret = zmq_setsockopt(z->out.socket, ZMQ_CURVE_PUBLICKEY,
                         z->curve.server.public_key, 41);
    if (ret)
      goto fail;

    int curve_server = 1;
    ret = zmq_setsockopt(z->out.socket, ZMQ_CURVE_SERVER, &curve_server,
                         sizeof(curve_server));
    if (ret)
      goto fail;

    // Create temporary client keys first
    ret = zmq_curve_keypair(z->curve.client.public_key,
                            z->curve.client.secret_key);
    if (ret)
      goto fail;

    // Subscriber has client role
    ret = zmq_setsockopt(z->in.socket, ZMQ_CURVE_SECRETKEY,
                         z->curve.client.secret_key, 41);
    if (ret)
      goto fail;

    ret = zmq_setsockopt(z->in.socket, ZMQ_CURVE_PUBLICKEY,
                         z->curve.client.public_key, 41);
    if (ret)
      goto fail;

    ret = zmq_setsockopt(z->in.socket, ZMQ_CURVE_SERVERKEY,
                         z->curve.server.public_key, 41);
    if (ret)
      goto fail;
  }

  for (auto d : dirs) {
    const char *mon_ep =
        d == &z->in ? "inproc://monitor-in" : "inproc://monitor-out";

    ret = zmq_setsockopt(d->socket, ZMQ_IPV6, &z->ipv6, sizeof(z->ipv6));
    if (ret)
      goto fail;

    int linger = 1000;
    ret = zmq_setsockopt(d->socket, ZMQ_LINGER, &linger, sizeof(linger));
    if (ret)
      goto fail;

    // Monitor events on the server
    ret = zmq_socket_monitor(d->socket, mon_ep, ZMQ_EVENT_ALL);
    if (ret < 0)
      goto fail;

    // Create socket for collecting monitor events
    d->mon_socket = zmq_socket(context, ZMQ_PAIR);
    if (!d->mon_socket) {
      ret = -1;
      goto fail;
    }

    // Connect it to the inproc endpoints so they'll get events
    ret = zmq_connect(d->mon_socket, mon_ep);
    if (ret < 0)
      goto fail;

    // Connect / bind sockets to endpoints
    for (size_t i = 0; i < list_length(&d->endpoints); i++) {
      char *ep = (char *)list_at(&d->endpoints, i);

      if (d->bind) {
        ret = zmq_bind(d->socket, ep);
        if (ret < 0)
          goto fail;
      } else {
        ret = zmq_connect(d->socket, ep);
        if (ret < 0)
          goto fail;
      }

      d->pending++;
    }
  }

  // Wait for all connections to be connected
  for (auto d : dirs) {
    while (d->pending > 0) {
      int evt = d->bind ? ZMQ_EVENT_LISTENING : ZMQ_EVENT_CONNECTED;

      ret = get_monitor_event(d->mon_socket, nullptr, nullptr);
      if (ret == evt)
        d->pending--;
    }
  }

#if defined(ZMQ_BUILD_DRAFT_API) && ZMQ_MAJOR_VERSION >= 4 &&                  \
    ZMQ_MINOR_VERSION >= 2 && ZMQ_MINOR_VERSION >= 3
  if (z->curve.enabled) {
    ret = get_monitor_event(z->in.mon_socket, nullptr, nullptr);
    return ret == ZMQ_EVENT_HANDSHAKE_SUCCEEDED;
  } else
    return 0; // The handshake events are only emitted for CurveZMQ sessions
#else
  return 0;
#endif

fail:
  n->logger->info("Failed to start: {}", zmq_strerror(errno));

  return ret;
}

int villas::node::zeromq_stop(NodeCompat *n) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  struct zeromq::Dir *dirs[] = {&z->out, &z->in};

  for (auto d : dirs) {
    ret = zmq_close(d->socket);
    if (ret)
      return ret;

    ret = zmq_close(d->mon_socket);
    if (ret)
      return ret;
  }

  return 0;
}

int villas::node::zeromq_destroy(NodeCompat *n) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  if (z->in.filter)
    free(z->in.filter);

  if (z->out.filter)
    free(z->out.filter);

  ret = list_destroy(&z->out.endpoints, nullptr, true);
  if (ret)
    return ret;

  if (z->formatter)
    delete z->formatter;

  return 0;
}

int villas::node::zeromq_read(NodeCompat *n, struct Sample *const smps[],
                              unsigned cnt) {
  int recv, ret;
  auto *z = n->getData<struct zeromq>();

  zmq_msg_t m;

  ret = zmq_msg_init(&m);
  if (ret < 0)
    return ret;

  if (z->in.filter) {
    switch (z->pattern) {
    case zeromq::Pattern::PUBSUB:
      // Discard envelope
      zmq_recv(z->in.socket, nullptr, 0, 0);
      break;

    default: {
    }
    }
  }

  // Receive payload
  ret = zmq_msg_recv(&m, z->in.socket, 0);
  if (ret < 0)
    return ret;

  recv = z->formatter->sscan((const char *)zmq_msg_data(&m), zmq_msg_size(&m),
                             nullptr, smps, cnt);

  ret = zmq_msg_close(&m);
  if (ret)
    return ret;

  return recv;
}

int villas::node::zeromq_write(NodeCompat *n, struct Sample *const smps[],
                               unsigned cnt) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  size_t wbytes;
  zmq_msg_t m;

  char data[DEFAULT_FORMAT_BUFFER_LENGTH];

  ret = z->formatter->sprint(data, sizeof(data), &wbytes, smps, cnt);
  if (ret <= 0)
    return -1;

  ret = zmq_msg_init_size(&m, wbytes);

  if (z->out.filter) {
    switch (z->pattern) {
#ifdef ZMQ_BUILD_DISH
    case zeromq::Pattern::RADIODISH:
      ret = zmq_msg_set_group(&m, z->out.filter);
      if (ret < 0)
        goto fail;
      break;
#endif

    case zeromq::Pattern::PUBSUB: // Send envelope
      zmq_send(z->out.socket, z->out.filter, strlen(z->out.filter),
               ZMQ_SNDMORE);
      break;
    }
  }

  memcpy(zmq_msg_data(&m), data, wbytes);

  ret = zmq_msg_send(&m, z->out.socket, 0);
  if (ret < 0)
    goto fail;

  ret = zmq_msg_close(&m);
  if (ret < 0)
    return ret;

  return cnt;

fail:
  zmq_msg_close(&m);

  return ret;
}

int villas::node::zeromq_poll_fds(NodeCompat *n, int fds[]) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  int fd;
  size_t len = sizeof(fd);

  ret = zmq_getsockopt(z->in.socket, ZMQ_FD, &fd, &len);
  if (ret)
    return ret;

  fds[0] = fd;

  return 1;
}

int villas::node::zeromq_netem_fds(NodeCompat *n, int fds[]) {
  int ret;
  auto *z = n->getData<struct zeromq>();

  int fd;
  size_t len = sizeof(fd);

  ret = zmq_getsockopt(z->out.socket, ZMQ_FD, &fd, &len);
  if (ret)
    return ret;

  fds[0] = fd;

  return 1;
}

static NodeCompatType p;

__attribute__((constructor(110))) static void register_plugin() {
  p.name = "zeromq";
  p.description = "ZeroMQ Distributed Messaging (libzmq)";
  p.vectorize = 0;
  p.size = sizeof(struct zeromq);
  p.type.start = zeromq_type_start;
  p.type.stop = zeromq_type_stop;
  p.init = zeromq_init;
  p.destroy = zeromq_destroy;
  p.check = zeromq_check;
  p.parse = zeromq_parse;
  p.print = zeromq_print;
  p.start = zeromq_start;
  p.stop = zeromq_stop;
  p.read = zeromq_read;
  p.write = zeromq_write;
  p.reverse = zeromq_reverse;
  p.poll_fds = zeromq_poll_fds;
  p.netem_fds = zeromq_netem_fds;

  static NodeCompatFactory ncp(&p);
}