/* Lua expressions hook.
 *
 * Author: Steffen Vogel <post@steffenvogel.de>
 * SPDX-FileCopyrightText: 2014-2023 Institute for Automation of Complex Power Systems, RWTH Aachen University
 * SPDX-License-Identifier: Apache-2.0
 */

#include <cmath>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

extern "C" {
#include <lauxlib.h>
#include <lua.h>
#include <lualib.h>
};

#include <villas/exceptions.hpp>
#include <villas/hooks/lua.hpp>
#include <villas/node.hpp>
#include <villas/path.hpp>
#include <villas/sample.hpp>
#include <villas/signal.hpp>
#include <villas/signal_list.hpp>
#include <villas/utils.hpp>

using namespace villas;
using namespace villas::node;

class LuaError : public RuntimeError {

protected:
  lua_State *L;
  int err;

public:
  LuaError(lua_State *l, int e) : RuntimeError(""), L(l), err(e) {}

  virtual const char *what() const noexcept {
    const char *msg;
    switch (err) {
    case LUA_ERRSYNTAX:
      msg = "Syntax error";
      break;

    case LUA_ERRMEM:
      msg = "Memory allocation error";
      break;

    case LUA_ERRFILE:
      msg = "Failed to open Lua script";
      break;

    case LUA_ERRRUN:
      msg = "Runtime error";
      break;

    case LUA_ERRERR:
      msg = "Failed to call error handler";
      break;

    default:
      msg = "Unknown error";
      break;
    }

    char *buf;

    if (asprintf(&buf, "Lua: %s: %s", msg, lua_tostring(L, -1)) < 0)
      return "Lua: could not format error message";

    return buf;
  }
};

/* Push @p ts as a new Lua table { [0] = tv_sec, [1] = tv_nsec }.
 * Note the 0-based keys, matching lua_totimespec(). */
static void lua_pushtimespec(lua_State *L, struct timespec *ts) {
  lua_createtable(L, 2, 0);
  const int table = lua_gettop(L);

  lua_pushnumber(L, ts->tv_sec);
  lua_rawseti(L, table, 0);

  lua_pushnumber(L, ts->tv_nsec);
  lua_rawseti(L, table, 1);
}

/* Read a { [0] = tv_sec, [1] = tv_nsec } table from the top of the stack into
 * @p ts. The table itself stays on the stack. */
static void lua_totimespec(lua_State *L, struct timespec *ts) {
  // Seconds
  lua_rawgeti(L, -1, 0);
  ts->tv_sec = lua_tonumber(L, -1);
  lua_pop(L, 1);

  // Nanoseconds
  lua_rawgeti(L, -1, 1);
  ts->tv_nsec = lua_tonumber(L, -1);
  lua_pop(L, 1);
}

/* Push the value @p data, interpreted according to the type of @p sig.
 *
 * Returns true if a value was pushed. Complex, invalid and unknown types are
 * not representable: nothing is pushed and false is returned, so the caller
 * leaves the entry unset (nil in Lua). */
static bool lua_pushsignaldata(lua_State *L, const union SignalData *data,
                               const Signal::Ptr sig) {
  const auto kind = sig->type;

  if (kind == SignalType::FLOAT) {
    lua_pushnumber(L, data->f);
    return true;
  }

  if (kind == SignalType::INTEGER) {
    lua_pushinteger(L, data->i);
    return true;
  }

  if (kind == SignalType::BOOLEAN) {
    lua_pushboolean(L, data->b);
    return true;
  }

  return false; // we skip unknown types. Lua will see a nil value in the table
}

/* Convert the Lua value at @p idx into @p data of type @p targetType.
 *
 * Booleans and numbers are accepted (numbers are read as FLOAT and then cast).
 * Any other Lua type leaves @p data unchanged. */
static void lua_tosignaldata(lua_State *L, union SignalData *data,
                             enum SignalType targetType, int idx = -1) {
  enum SignalType sourceType;
  const int valueType = lua_type(L, idx);

  if (valueType == LUA_TBOOLEAN) {
    data->b = lua_toboolean(L, idx);
    sourceType = SignalType::BOOLEAN;
  } else if (valueType == LUA_TNUMBER) {
    data->f = lua_tonumber(L, idx);
    sourceType = SignalType::FLOAT;
  } else {
    return;
  }

  *data = data->cast(sourceType, targetType);
}

/* Update @p smp from the Lua sample table at stack index @p idx.
 *
 * The layout mirrors lua_pushsample(). The optional fields are:
 *  - "sequence"
 *  - "ts_origin" and "ts_received" (tables { [0] = sec, [1] = nsec })
 *  - "data"
 *
 * Flags and length are rebuilt from the fields present; "flags" itself is not
 * read back. For "data", one entry per signal in @p signals is read. Entries
 * are keyed by signal name if @p use_names is set, otherwise by 0-based index.
 * Missing entries are replaced by the signal's initial value. At most
 * smp->capacity values are written.
 *
 * Timestamp and data fields that are not tables are ignored. Indexing a
 * non-table would raise a Lua error outside of any protected call.
 * The stack is left balanced, and the sample table itself is not popped.
 */
static void lua_tosample(lua_State *L, struct Sample *smp,
                         SignalList::Ptr signals, bool use_names = true,
                         int idx = -1) {
  smp->length = 0;
  smp->flags = 0;

  lua_getfield(L, idx, "sequence");
  if (lua_type(L, -1) != LUA_TNIL) {
    smp->sequence = lua_tonumber(L, -1);
    smp->flags |= (int)SampleFlags::HAS_SEQUENCE;
  }
  lua_pop(L, 1);

  lua_getfield(L, idx, "ts_origin");
  if (lua_type(L, -1) == LUA_TTABLE) {
    lua_totimespec(L, &smp->ts.origin);
    smp->flags |= (int)SampleFlags::HAS_TS_ORIGIN;
  }
  lua_pop(L, 1);

  lua_getfield(L, idx, "ts_received");
  if (lua_type(L, -1) == LUA_TTABLE) {
    lua_totimespec(L, &smp->ts.received);
    smp->flags |= (int)SampleFlags::HAS_TS_RECEIVED;
  }
  lua_pop(L, 1);

  lua_getfield(L, idx, "data");
  if (lua_type(L, -1) == LUA_TTABLE) {
    unsigned i = 0;
    for (const auto &sig : *signals) {
      // Never write past the memory reserved for the sample's values
      if (i >= smp->capacity)
        break;

      if (use_names)
        lua_getfield(L, -1, sig->name.c_str());
      else
        lua_rawgeti(L, -1, i);

      if (lua_type(L, -1) != LUA_TNIL)
        lua_tosignaldata(L, &smp->data[i], sig->type, -1);
      else
        smp->data[i] = sig->init;

      lua_pop(L, 1);
      i++;
      smp->length++;
    }

    if (smp->length > 0)
      smp->flags |= (int)SampleFlags::HAS_DATA;
  }
  lua_pop(L, 1);
}

/* Push a Lua table describing @p smp onto the stack.
 *
 * Fields:
 *  - "flags": always present.
 *  - "sequence", "ts_origin", "ts_received": only if the corresponding flag
 *    is set. Timestamps are { [0] = sec, [1] = nsec }.
 *  - "data": if HAS_DATA is set. Keyed by signal name when @p use_names is
 *    true, otherwise by 0-based index.
 *
 * Values that lua_pushsignaldata() cannot represent, or that have no signal
 * metadata, are omitted (nil in Lua).
 */
static void lua_pushsample(lua_State *L, struct Sample *smp,
                           bool use_names = true) {
  lua_createtable(L, 0, 5); // Up to five named fields

  lua_pushnumber(L, smp->flags);
  lua_setfield(L, -2, "flags");

  if (smp->flags & (int)SampleFlags::HAS_SEQUENCE) {
    lua_pushnumber(L, smp->sequence);
    lua_setfield(L, -2, "sequence");
  }

  if (smp->flags & (int)SampleFlags::HAS_TS_ORIGIN) {
    lua_pushtimespec(L, &smp->ts.origin);
    lua_setfield(L, -2, "ts_origin");
  }

  if (smp->flags & (int)SampleFlags::HAS_TS_RECEIVED) {
    lua_pushtimespec(L, &smp->ts.received);
    lua_setfield(L, -2, "ts_received");
  }

  if (smp->flags & (int)SampleFlags::HAS_DATA) {
    // Preallocate the part of the table that will actually be filled:
    // the hash part for name keys, the array part for integer keys.
    if (use_names)
      lua_createtable(L, 0, smp->length);
    else
      lua_createtable(L, smp->length, 0);

    for (unsigned i = 0; i < smp->length; i++) {
      const auto sig = smp->signals->getByIndex(i);
      if (!sig)
        continue; // Without metadata the value cannot be interpreted

      if (!lua_pushsignaldata(L, &smp->data[i], sig))
        continue;

      if (use_names)
        lua_setfield(L, -2, sig->name.c_str());
      else
        lua_rawseti(L, -2, i);
    }

    lua_setfield(L, -2, "data");
  }
}

/* Push a Lua representation of @p json onto the stack (always exactly one
 * value).
 *
 * Conversions:
 *  - Objects become name-keyed tables.
 *  - Arrays become tables with 0-based integer keys, the convention used
 *    throughout this file.
 *  - Integers and reals become Lua numbers.
 *  - JSON null and a NULL pointer become nil.
 */
static void lua_pushjson(lua_State *L, json_t *json) {
  size_t i;
  const char *key;
  json_t *json_value;

  // e.g. a hook without configuration
  if (!json) {
    lua_pushnil(L);
    return;
  }

  switch (json_typeof(json)) {
  case JSON_OBJECT:
    lua_newtable(L);
    json_object_foreach(json, key, json_value) {
      lua_pushjson(L, json_value);
      lua_setfield(L, -2, key);
    }
    break;

  case JSON_ARRAY:
    lua_newtable(L);
    json_array_foreach(json, i, json_value) {
      lua_pushjson(L, json_value);
      lua_rawseti(L, -2, i);
    }
    break;

  case JSON_STRING:
    lua_pushstring(L, json_string_value(json));
    break;

  case JSON_REAL:
    // json_integer_value() returns 0 for reals, so read the real value
    lua_pushnumber(L, json_real_value(json));
    break;

  case JSON_INTEGER:
    lua_pushnumber(L, json_integer_value(json));
    break;

  case JSON_TRUE:
  case JSON_FALSE:
    lua_pushboolean(L, json_boolean_value(json));
    break;

  case JSON_NULL:
  default:
    // Always push something so callers' stack arithmetic stays valid
    lua_pushnil(L);
    break;
  }
}

/* Convert the Lua value at stack index @p index into a new JSON value.
 *
 * The caller owns the returned reference. nil, functions, userdata and
 * threads become JSON null.
 *
 * Numbers become JSON integers when they are integral and exactly
 * representable (|n| <= 2^53), otherwise JSON reals.
 *
 * Tables become JSON arrays when they are non-empty, all keys are
 * non-negative integers, and at least half of the slots 0..max_key are used.
 * Key k maps to array index k, matching lua_pushjson(); gaps are filled with
 * null. All other tables become JSON objects. Their string and number keys
 * are used as names; other key types are skipped. Empty tables become
 * objects.
 *
 * The stack is left balanced.
 * NOTE(review): cyclic tables recurse without bound.
 */
static json_t *lua_tojson(lua_State *L, int index = -1) {
  // Relative indices shift while keys/values are pushed during table
  // traversal, so convert to an absolute index (pseudo-indices stay as is).
  if (index < 0 && index > LUA_REGISTRYINDEX)
    index = lua_gettop(L) + index + 1;

  switch (lua_type(L, index)) {
  case LUA_TFUNCTION:
  case LUA_TUSERDATA:
  case LUA_TTHREAD:
  case LUA_TLIGHTUSERDATA:
  case LUA_TNIL:
    return json_null();

  case LUA_TNUMBER: {
    const double n = lua_tonumber(L, index);

    // Bounding |n| keeps the integer conversion well-defined
    const double max_exact = 9007199254740992.0; // 2^53
    if (n == std::floor(n) && std::fabs(n) <= max_exact)
      return json_integer(static_cast<long long>(n));

    return json_real(n);
  }

  case LUA_TBOOLEAN:
    return json_boolean(lua_toboolean(L, index));

  case LUA_TSTRING:
    return json_string(lua_tostring(L, index));

  case LUA_TTABLE: {
    size_t keys_total = 0, keys_int = 0;
    double key_highest = -1;

    // First pass: classify the keys to choose between array and object
    lua_pushnil(L);
    while (lua_next(L, index) != 0) {
      keys_total++;
      if (lua_type(L, -2) == LUA_TNUMBER) {
        const double key = lua_tonumber(L, -2);

        if (key >= 0 && key == std::floor(key)) {
          keys_int++;
          if (key > key_highest)
            key_highest = key;
        }
      }
      lua_pop(L, 1); // Pop the value, keep the key for lua_next()
    }

    const bool is_array = keys_total > 0 && keys_total == keys_int &&
                          2.0 * keys_int >= key_highest + 1;

    if (is_array) {
      // key_highest < 2 * keys_int here, so the size is bounded
      const size_t size = static_cast<size_t>(key_highest) + 1;
      json_t *json = json_array();

      // Pre-fill so every key can be assigned by index; gaps stay null
      for (size_t i = 0; i < size; i++)
        json_array_append_new(json, json_null());

      lua_pushnil(L);
      while (lua_next(L, index) != 0) {
        const size_t key = static_cast<size_t>(lua_tonumber(L, -2));

        // *_set_new steals the reference returned by lua_tojson()
        json_array_set_new(json, key, lua_tojson(L, -1));
        lua_pop(L, 1);
      }

      return json;
    }

    json_t *json = json_object();

    lua_pushnil(L);
    while (lua_next(L, index) != 0) {
      const int key_type = lua_type(L, -2);

      // Skip table entries whose keys are neither string nor number
      if (key_type == LUA_TSTRING || key_type == LUA_TNUMBER) {
        // Convert a copy of the key: lua_tostring() on the key itself would
        // turn number keys into strings in place and confuse lua_next().
        lua_pushvalue(L, -2);
        const char *key = lua_tostring(L, -1);

        json_object_set_new(json, key, lua_tojson(L, -2));
        lua_pop(L, 1); // Key copy
      }
      lua_pop(L, 1); // Value
    }

    return json;
  }
  }

  return nullptr;
}

namespace villas {
namespace node {

/* Create an expression for one output signal from its JSON configuration.
 *
 * The mandatory "expression" string is stored. It is compiled later by
 * prepare(). Throws ConfigError if the setting is missing or not a string. */
LuaSignalExpression::LuaSignalExpression(lua_State *l, json_t *json_sig)
    : cookie(0), L(l) {
  const char *expr = nullptr;
  json_error_t err;

  if (json_unpack_ex(json_sig, &err, 0, "{ s: s }", "expression", &expr) != 0)
    throw ConfigError(json_sig, err, "node-config-hook-lua-signals");

  // Kept for error reporting when the expression fails to compile
  cfg = json_sig;
  expression = expr;
}

void LuaSignalExpression::prepare() {
  // Compile the configured expression into a Lua chunk
  parseExpression(expression);
}

void LuaSignalExpression::parseExpression(const std::string &expr) {
  // Release previous expression
  if (cookie)
    luaL_unref(L, LUA_REGISTRYINDEX, cookie);

  auto fexpr = fmt::format("return {}", expr);

  int err = luaL_loadstring(L, fexpr.c_str());
  if (err)
    throw ConfigError(cfg, "node-config-hook-lua-signals",
                      "Failed to load Lua expression: {}", lua_tostring(L, -1));

  cookie = luaL_ref(L, LUA_REGISTRYINDEX);
}

/* Run the compiled expression and store its result in @p data as @p type.
 *
 * Throws RuntimeError if the evaluation fails. The Lua stack is left balanced
 * in both cases; previously the error message stayed on the stack because the
 * pop after the throw was unreachable. */
void LuaSignalExpression::evaluate(union SignalData *data,
                                   enum SignalType type) {
  lua_rawgeti(L, LUA_REGISTRYINDEX, cookie);

  int err = lua_pcall(L, 0, 1, 0);
  if (err) {
    // Copy the message before popping it; it may also be a non-string (NULL)
    const char *msg = lua_tostring(L, -1);
    std::string reason = msg ? msg : "unknown error";
    lua_pop(L, 1);

    throw RuntimeError("Lua: Evaluation failed: {}", reason);
  }

  lua_tosignaldata(L, data, type, -1);

  lua_pop(L, 1);
}

/* Create the hook together with its own Lua state.
 *
 * Throws RuntimeError if the Lua state cannot be allocated. */
LuaHook::LuaHook(Path *p, Node *n, int fl, int prio, bool en)
    : Hook(p, n, fl, prio, en),
      signalsExpressions(std::make_shared<SignalList>()), L(luaL_newstate()),
      useNames(true), hasExpressions(false), needsLocking(false),
      functions({0}) {
  // luaL_newstate() returns NULL when it cannot allocate the state; every
  // later Lua API call would then dereference a null pointer.
  if (!L)
    throw RuntimeError("Lua: Failed to create Lua state");
}

LuaHook::~LuaHook() {
  // Closing the state releases every Lua object owned by this hook
  lua_close(L);
}

/* Parse the "signals" setting: a list of signal definitions, each carrying an
 * "expression". Expression i computes output signal i.
 *
 * Throws ConfigError if the list cannot be parsed. */
void LuaHook::parseExpressions(json_t *json_sigs) {
  int ret;
  size_t i;
  json_t *json_sig;

  // Start from scratch in both lists: expression i must belong to signal i,
  // so entries from a previous parse() must not linger in `expressions`.
  // NOTE(review): if expressions were already prepared, their registry
  // references are only released if LuaSignalExpression's destructor does so.
  signalsExpressions->clear();
  expressions.clear();

  ret = signalsExpressions->parse(json_sigs);
  if (ret)
    throw ConfigError(json_sigs, "node-config-hook-lua-signals",
                      "Setting 'signals' must be a list of dicts");

  // cppcheck-suppress unknownMacro
  json_array_foreach(json_sigs, i, json_sig)
      expressions.emplace_back(L, json_sig);

  hasExpressions = true;
}

/* Parse the hook configuration.
 *
 * Optional settings:
 *  - "script": path of the Lua script file.
 *  - "signals": expression list, see parseExpressions().
 *  - "use_names": key sample data by signal name (default: true).
 *
 * Throws ConfigError on malformed settings. */
void LuaHook::parse(json_t *json) {
  const char *script_str = nullptr;
  json_t *json_signals = nullptr;
  int names = 1;
  json_error_t err;

  assert(state != State::STARTED);

  Hook::parse(json);

  if (json_unpack_ex(json, &err, 0, "{ s?: s, s?: o, s?: b }", "script",
                     &script_str, "signals", &json_signals, "use_names",
                     &names) != 0)
    throw ConfigError(json, err, "node-config-hook-lua");

  useNames = names;

  if (script_str != nullptr)
    script = script_str;

  if (json_signals != nullptr)
    parseExpressions(json_signals);

  state = State::PARSED;
}

/* Look up the optional global callback functions defined by the script.
 *
 * Each function that is found is left on the Lua stack. Its absolute stack
 * index is stored in `functions`, so it can later be pushed again with
 * lua_pushvalue(). Missing callbacks (or non-function globals) get index 0. */
void LuaHook::lookupFunctions() {
  std::map<const char *, int *> funcs = {
      {"start", &functions.start},       {"stop", &functions.stop},
      {"restart", &functions.restart},   {"prepare", &functions.prepare},
      {"periodic", &functions.periodic}, {"process", &functions.process}};

  for (const auto &entry : funcs) {
    const char *name = entry.first;
    int *slot = entry.second;

    lua_getglobal(L, name);

    if (lua_type(L, -1) != LUA_TFUNCTION) {
      // Not a callback: forget it and restore the stack
      *slot = 0;
      lua_pop(L, 1);
      continue;
    }

    logger->debug("Found Lua function: {}()", name);
    *slot = lua_gettop(L);
  }
}

/* Load the configured script file and run it once so that it can define its
 * globals. Does nothing if no script is configured.
 *
 * Throws LuaError if the file cannot be loaded or fails while running. */
void LuaHook::loadScript() {
  if (script.empty())
    return; // No script given

  int status = luaL_loadfile(L, script.c_str());
  if (status == 0)
    status = lua_pcall(L, 0, LUA_MULTRET, 0);

  if (status != 0)
    throw LuaError(L, status);
}

int LuaHook::luaRegisterApiHandler(lua_State *L) {
  // Lua signature: register_api_handler(path_regex)
  // Not implemented yet: arguments are ignored and no results are returned.
  (void)L;
  return 0;
}

int LuaHook::luaInfo(lua_State *L) {
  // info(msg): log at info level; luaL_checkstring raises a Lua error if the
  // argument is not a string (or number).
  const char *message = luaL_checkstring(L, 1);
  logger->info(message);

  return 0; // No results
}

int LuaHook::luaWarn(lua_State *L) {
  // warn(msg): log at warning level; luaL_checkstring raises a Lua error if
  // the argument is not a string (or number).
  const char *message = luaL_checkstring(L, 1);
  logger->warn(message);

  return 0; // No results
}

int LuaHook::luaError(lua_State *L) {
  // error(msg): log at error level (does not raise a Lua error itself);
  // luaL_checkstring raises one if the argument is not a string (or number).
  const char *message = luaL_checkstring(L, 1);
  logger->error(message);

  return 0; // No results
}

int LuaHook::luaDebug(lua_State *L) {
  // debug(msg): log at debug level; luaL_checkstring raises a Lua error if
  // the argument is not a string (or number).
  const char *message = luaL_checkstring(L, 1);
  logger->debug(message);

  return 0; // No results
}

void LuaHook::setupEnvironment() {
  lua_pushlightuserdata(L, this);
  lua_rawseti(L, LUA_REGISTRYINDEX, SELF_REFERENCE);

  lua_register(L, "info", &dispatch<&LuaHook::luaInfo>);
  lua_register(L, "warn", &dispatch<&LuaHook::luaWarn>);
  lua_register(L, "error", &dispatch<&LuaHook::luaError>);
  lua_register(L, "debug", &dispatch<&LuaHook::luaDebug>);

  lua_register(L, "register_api_handler",
               &dispatch<&LuaHook::luaRegisterApiHandler>);
}

/* Set up the Lua state, load the script, compile expressions and finally run
 * the script's prepare(config) callback, if any.
 *
 * Throws LuaError / ConfigError on failures. */
void LuaHook::prepare() {
  // Opened before setupEnvironment(), whose globals therefore take precedence
  luaL_openlibs(L);

  logger->debug("Loading Lua script: {}", script);

  setupEnvironment();
  loadScript();
  lookupFunctions();

  // periodic() gets called from the main thread, so the Lua state must be
  // protected by a mutex as soon as the script defines it.
  needsLocking = functions.periodic > 0;

  // Altering signal metadata in process() is not supported: it works on the
  // hook's input signals.
  if (functions.process)
    signalsProcessed = signals;

  // Compile the expressions; they define the output signals of this hook
  if (hasExpressions) {
    for (auto &signalExpression : expressions)
      signalExpression.prepare();

    signals = signalsExpressions;
  }

  const bool isNoop = !functions.process && !hasExpressions;
  if (isNoop)
    logger->warn(
        "The hook has neither a script or expressions defined. It is a no-op!");

  if (!functions.prepare)
    return;

  std::unique_lock<std::mutex> lockScope;
  if (needsLocking)
    lockScope = std::unique_lock<std::mutex>(mutex);

  // prepare() receives the hook configuration converted to a Lua value
  logger->debug("Executing Lua function: prepare()");
  lua_pushvalue(L, functions.prepare);
  lua_pushjson(L, config);

  const int status = lua_pcall(L, 1, 0, 0);
  if (status != 0)
    throw LuaError(L, status);
}

/* Run the script's start() callback (if any) and mark the hook as started.
 * Throws LuaError if the callback fails. */
void LuaHook::start() {
  assert(state == State::PREPARED);

  std::unique_lock<std::mutex> guard;
  if (needsLocking)
    guard = std::unique_lock<std::mutex>(mutex);

  if (functions.start) {
    logger->debug("Executing Lua function: start()");
    lua_pushvalue(L, functions.start);

    const int status = lua_pcall(L, 0, 0, 0);
    if (status != 0)
      throw LuaError(L, status);
  }

  state = State::STARTED;
}

/* Run the script's stop() callback (if any) and mark the hook as stopped.
 * Throws LuaError if the callback fails. */
void LuaHook::stop() {
  assert(state == State::STARTED);

  std::unique_lock<std::mutex> guard;
  if (needsLocking)
    guard = std::unique_lock<std::mutex>(mutex);

  if (functions.stop) {
    logger->debug("Executing Lua function: stop()");
    lua_pushvalue(L, functions.stop);

    const int status = lua_pcall(L, 0, 0, 0);
    if (status != 0)
      throw LuaError(L, status);
  }

  state = State::STOPPED;
}

/* Run the script's restart() callback, or fall back to Hook::restart() if the
 * script does not define one. Throws LuaError if the callback fails. */
void LuaHook::restart() {
  assert(state == State::STARTED);

  if (!functions.restart) {
    // Do not hold the lock here: the default Hook::restart() may call our
    // stop()/start(), which acquire the (non-recursive) mutex themselves and
    // would otherwise deadlock.
    Hook::restart();
    return;
  }

  auto lockScope = needsLocking ? std::unique_lock<std::mutex>(mutex)
                                : std::unique_lock<std::mutex>();

  logger->debug("Executing Lua function: restart()");
  lua_pushvalue(L, functions.restart);
  int ret = lua_pcall(L, 0, 0, 0);
  if (ret)
    throw LuaError(L, ret);
}

/* Run the script's periodic() callback, if any. Throws LuaError if it fails. */
void LuaHook::periodic() {
  assert(state == State::STARTED);

  if (!functions.periodic)
    return;

  auto lockScope = needsLocking ? std::unique_lock<std::mutex>(mutex)
                                : std::unique_lock<std::mutex>();

  // Log the function actually being executed (was wrongly "restart()")
  logger->debug("Executing Lua function: periodic()");
  lua_pushvalue(L, functions.periodic);
  int ret = lua_pcall(L, 0, 0, 0);
  if (ret)
    throw LuaError(L, ret);
}

/* Process one sample.
 *
 * 1. If the script defines process(smp), it is called with a Lua copy of the
 *    sample. Its numeric return value becomes the hook's Reason (OK
 *    otherwise), and the sample is updated from the (possibly modified) table.
 * 2. If expressions are configured, the sample is exposed as global "smp" and
 *    expression i writes output value i (bounded by smp->capacity).
 *
 * Throws LuaError / RuntimeError on Lua failures. The Lua stack is left
 * balanced on success.
 */
Hook::Reason LuaHook::process(struct Sample *smp) {
  if (!functions.process && !hasExpressions)
    return Reason::OK;

  enum Reason reason = Reason::OK;
  auto lockScope = needsLocking ? std::unique_lock<std::mutex>(mutex)
                                : std::unique_lock<std::mutex>();

  // First, run the process() function of the script
  if (functions.process) {
    logger->debug("Executing Lua function: process(smp)");

    lua_pushsample(L, smp, useNames); // [..., sample]

    lua_pushvalue(L, functions.process);
    lua_pushvalue(L, -2); // Push a copy since lua_pcall() will pop it
    int ret = lua_pcall(L, 1, 1, 0);
    if (ret) {
      // [..., sample, error]: drop the sample table but keep the error
      // message on top, where LuaError reads it.
      lua_remove(L, -2);
      throw LuaError(L, ret);
    }

    // [..., sample, result]
    if (lua_type(L, -1) == LUA_TNUMBER) {
      reason = static_cast<Reason>(static_cast<int>(lua_tonumber(L, -1)));
    } else {
      logger->warn(
          "Lua process() did not return a valid number. Assuming Reason::OK");
    }
    lua_pop(L, 1);

    // Read back the sample table, then pop it: previously it was left on the
    // stack, which grew by one slot for every processed sample.
    lua_tosample(L, smp, signalsProcessed, useNames);
    lua_pop(L, 1);
  }

  // After that evaluate expressions
  if (hasExpressions) {
    lua_pushsample(L, smp, useNames);
    lua_setglobal(L, "smp");

    // Never write past the memory reserved for the sample's values
    unsigned count = expressions.size();
    if (count > smp->capacity)
      count = smp->capacity;

    for (unsigned i = 0; i < count; i++) {
      auto sig = signalsExpressions->getByIndex(i);
      if (!sig)
        continue;

      expressions[i].evaluate(&smp->data[i], sig->type);
    }

    smp->length = count;
  }

  return reason;
}

// Register hook
// Makes LuaHook available as plugin "lua" (n) with the description in d.
// NOTE(review): the flags appear to allow the hook on node reads, node writes
// and paths; the trailing 1 is presumably the default priority — confirm
// against HookPlugin's template parameters.
static char n[] = "lua";
static char d[] = "Implement hook functions or expressions in Lua";
static HookPlugin<LuaHook, n, d,
                  (int)Hook::Flags::NODE_READ | (int)Hook::Flags::NODE_WRITE |
                      (int)Hook::Flags::PATH,
                  1>
    p;

} // namespace node
} // namespace villas