/*
  _____ _ _ _                    _             _
 |  ___(_) | |_ ___ _ __   _ __ | |_   _  __ _(_)_ __
 | |_  | | | __/ _ \ '__| | '_ \| | | | |/ _` | | '_ \
 |  _| | | | ||  __/ |    | |_) | | |_| | (_| | | | | |
 |_|   |_|_|\__\___|_|    | .__/|_|\__,_|\__, |_|_| |_|
                          |_|            |___/
# A Template for StatsPlugin, a Filter Plugin
# Generated by the command: plugin --type filter --dir . stats
# Hostname: Fram-IV-3.local
# Current working directory: /Users/p4010/Develop/mads_doc/ai_example
# Creation date: 2026-04-10T17:08:34.374+0200
# NOTICE: MADS Version 2.0.4
*/
// Mandatory included headers
#include <filter.hpp>
#include <nlohmann/json.hpp>
#include <pugg/Kernel.h>
#include "test_helpers.hpp"

#include <algorithm>
#include <cmath>
#include <deque>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <vector>

// Define the name of the plugin
#ifndef PLUGIN_NAME
#define PLUGIN_NAME "stats"
#endif

using namespace std;
using json = nlohmann::json;

/*!
 * @brief Filter plugin that computes running mean and population standard
 * deviation.
 */
/*!
 * @brief Filter plugin that computes running mean and population standard
 * deviation.
 *
 * Numeric samples from each input's `data` array are kept in a FIFO buffer
 * holding at most `window` values. `load_data` reports `success` once at least
 * `stride` new samples have arrived since the previous success. `process` then
 * summarises the samples currently in the buffer.
 */
class StatsPlugin : public Filter<json, json> {
public:
  /*!
   * @brief Inherit the base filter constructors.
   */
  using Filter::Filter;

  /*!
   * @brief Return the plugin kind used by the MADS plugin loader.
   * @return The plugin name.
   */
  string kind() override { return PLUGIN_NAME; }

  /*!
   * @brief Append incoming scalar values to the running statistics buffer.
   *
   * The whole `data` array is validated before anything is buffered. A
   * rejected message therefore leaves the running window and the stride
   * counter untouched.
   *
   * @param input JSON object containing a numeric `data` array.
   * @param topic Unused topic string required by the filter API.
   * @param blob Unused binary payload pointer required by the API.
   * @return `success` when the stride threshold is met, `retry` when more input
   * is needed, or `error` for schema violations.
   */
  return_type load_data(json const &input, string topic = "",
                        vector<unsigned char> const *blob = nullptr) override {
    (void)topic;
    (void)blob;

    if (!input.contains("data") || !input["data"].is_array()) {
      _error = "Input must contain a numeric data array";
      return return_type::error;
    }

    const json &data = input["data"];

    // Reject the message up front. Pushing values while scanning would leave
    // a partially-applied message in the window when a later element is bad.
    const bool all_numeric =
        all_of(data.begin(), data.end(),
               [](const json &value) { return value.is_number(); });
    if (!all_numeric) {
      _error = "Input data array must contain only numeric values";
      return return_type::error;
    }

    for (const auto &value : data) {
      _buffer.push_back(value.get<double>());
    }
    // Dropping the oldest samples once at the end gives the same buffer as
    // trimming after every push, with a single window lookup.
    trim_to_window();

    _pending_since_success += data.size();
    const size_t stride = current_stride();
    if (_pending_since_success >= stride) {
      // Keep the surplus so that large batches still count toward the next
      // stride.
      _pending_since_success %= stride;
      return return_type::success;
    }

    return return_type::retry;
  }

  /*!
   * @brief Compute statistics over the current running window.
   * @param[out] out JSON object receiving `count`, `mean`, `stddev` and, when
   * an agent id is set, `agent_id`. It is cleared first in every case.
   * @param[in,out] blob Unused binary payload pointer required by the API.
   * @return `success` when statistics are produced, or `retry` when no values
   * are buffered yet.
   */
  return_type process(json &out,
                      vector<unsigned char> *blob = nullptr) override {
    (void)blob;
    out.clear();

    if (_buffer.empty()) {
      return return_type::retry;
    }

    const double count = static_cast<double>(_buffer.size());
    double sum = 0.0;
    for (double value : _buffer) {
      sum += value;
    }
    const double mean = sum / count;

    // Two-pass variance: the squared deviations are taken from the already
    // computed mean, which avoids the cancellation of the sum-of-squares form.
    double squared_error_sum = 0.0;
    for (double value : _buffer) {
      const double diff = value - mean;
      squared_error_sum += diff * diff;
    }

    out["count"] = _buffer.size();
    out["mean"] = mean;
    // Population standard deviation: divide by N, not N - 1.
    out["stddev"] = sqrt(squared_error_sum / count);
    if (!_agent_id.empty()) {
      out["agent_id"] = _agent_id;
    }
    return return_type::success;
  }

  /*!
   * @brief Configure the running window and stride parameters.
   *
   * Defaults are `window = 100` and `stride = 50`. A zero value for either is
   * raised to 1. The buffer is trimmed to the new window, and the stride
   * counter is reset.
   *
   * @param params JSON configuration containing `window` and `stride`. A
   * non-object value is treated as "no overrides".
   */
  void set_params(const json &params) override {
    Filter::set_params(params);
    _params = json::object();
    _params["window"] = 100u;
    _params["stride"] = 50u;
    // A non-object merge patch would replace `_params` wholesale (RFC 7396).
    // The value() lookups below would then throw, so only overlay objects.
    if (params.is_object()) {
      _params.merge_patch(params);
    }

    const unsigned int window = _params.value("window", 100u);
    // NOTE(review): "stride" is pre-seeded with 50 above, so this
    // `window / 2` fallback only applies when params sets "stride" to null
    // (merge_patch deletes the key). Kept as-is to preserve current defaults.
    unsigned int stride = _params.value("stride", window / 2u);
    if (window == 0u) {
      _params["window"] = 1u;
    }
    if (stride == 0u) {
      stride = 1u;
    }
    _params["stride"] = stride;

    trim_to_window();
    _pending_since_success = 0;
  }

  /*!
   * @brief Report the current filter configuration.
   * @return Map containing the configured window width and stride.
   */
  map<string, string> info() override {
    return {
        {"window", to_string(current_window())},
        {"stride", to_string(current_stride())},
    };
  }

private:
  /// Configured window size, never smaller than 1.
  size_t current_window() const {
    return static_cast<size_t>(max(1u, _params.value("window", 100u)));
  }

  /// Configured stride, never smaller than 1.
  size_t current_stride() const {
    return static_cast<size_t>(max(1u, _params.value("stride", 50u)));
  }

  /// Drop the oldest samples until the buffer fits the configured window.
  void trim_to_window() {
    const size_t window = current_window();
    while (_buffer.size() > window) {
      _buffer.pop_front();
    }
  }

  deque<double> _buffer;             ///< Most recent samples, oldest first.
  size_t _pending_since_success = 0; ///< Samples received since last success.
};

// Plugin-driver boilerplate from filter.hpp for StatsPlugin, with json input
// and json output types. Presumably this exports the entry point the pugg
// kernel uses to load the plugin; confirm against the macro definition in
// filter.hpp.
INSTALL_FILTER_DRIVER(StatsPlugin, json, json);

/*!
 * @brief Self-test entry point: runs the shared stats test suite on StatsPlugin.
 * @return 0 when every test passes (after printing a confirmation), 1 otherwise.
 */
int main(int argc, char const *argv[]) {
  // Command-line arguments are not used by the self-test driver.
  static_cast<void>(argc);
  static_cast<void>(argv);

  if (test_helpers::run_stats_tests<StatsPlugin>()) {
    cout << "stats tests passed" << endl;
    return 0;
  }

  return 1;
}
