/**
 * SMIoT JASON Qualilife sound recording class.
 *
 * Author: Jan Schlüter <jan.schluter@lis-lab.fr>
 * Author: Maxence Ferrari <maxence.ferrari@lis-lab.fr>
 */

#include "recorder.h"
#include <algorithm>
#include <array>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

/**
 * Create a libusb context and collect every attached device whose vendor/product ID
 * matches VENDOR_ID/PRODUCT_ID. One libusb reference is kept for each matching device
 * (released in the destructor). References to all other devices are dropped here.
 *
 * @param verbose  stored in the `verbose` member (controls status printing in get_samples()).
 * @throws std::runtime_error if libusb cannot be initialized or devices cannot be listed.
 */
JasonRecorder::JasonRecorder(bool verbose) : verbose(verbose) {
    // create libusb context
    if (libusb_init(&ctx) < 0) {
        throw std::runtime_error("libusb initialization failed");
    }

    // set debug level
    libusb_set_debug(ctx, 3);

    // discover JASON sound cards
    libusb_device** all_devices;
    if (libusb_get_device_list(ctx, &all_devices) < 0) {
        // the destructor does not run when a constructor throws, so release the context here
        libusb_exit(ctx);
        throw std::runtime_error("libusb device enumeration failed");
    }
    // the list is NULL-terminated; advance on every iteration, including failed descriptor reads
    for (libusb_device** device = all_devices; *device != nullptr; ++device) {
        struct libusb_device_descriptor desc;
        const bool is_jason = (libusb_get_device_descriptor(*device, &desc) >= 0)
                              && (desc.idVendor == VENDOR_ID)
                              && (desc.idProduct == PRODUCT_ID);
        if (is_jason) {
            // keep the reference the list holds; it is dropped in the destructor
            devices.push_back(*device);
        }
        else {
            // the list is freed without unreferencing (second argument 0), so drop it manually
            libusb_unref_device(*device);
        }
    }
    libusb_free_device_list(all_devices, 0);
}

/**
 * Release everything acquired by the constructor and set_device(): the claimed
 * interface and open handle (if any), the kept device references, and finally
 * the libusb context itself.
 */
JasonRecorder::~JasonRecorder() {
    // an open device has interface 0 claimed; give it back before closing the handle
    if (handle != NULL) {
        libusb_release_interface(handle, 0);
        libusb_close(handle);
    }
    // drop the device references kept during enumeration
    for (size_t idx = 0; idx < devices.size(); ++idx) {
        libusb_unref_device(devices[idx]);
    }
    // tear down the context last, after all handles and devices are released
    libusb_exit(ctx);
}

/**
 * @return the number of matching JASON devices found by the constructor; valid
 *         indices for set_device() are 0 .. count-1.
 */
size_t JasonRecorder::get_device_count() {
    const size_t count = devices.size();
    return count;
}

/**
 * Open device number `number` (an index into the devices found by the constructor)
 * and claim its interface 0. Any previously opened device is released first.
 *
 * On failure, no device is left half-open. After an open or claim error, `handle` is NULL.
 * An out-of-range index is rejected before the current device is touched.
 *
 * @param number  index in [0, get_device_count()).
 * @throws std::out_of_range if `number` is not a valid device index.
 * @throws std::runtime_error if the device cannot be opened or its interface claimed.
 */
void JasonRecorder::set_device(size_t number) {
    // validate first so a bad index leaves the currently open device usable
    if (number >= devices.size()) {
        throw std::out_of_range("device number too large");
    }
    // release any previously opened device
    if (handle) {
        libusb_release_interface(handle, 0);
        libusb_close(handle);
        handle = NULL;
    }
    if (libusb_open(devices[number], &handle) < 0) {
        handle = NULL;
        throw std::runtime_error("could not open USB device (try again as root)");
    }
    if (libusb_claim_interface(handle, 0) < 0) {
        // do not keep a handle whose interface is not claimed: send/receive would use it
        libusb_close(handle);
        handle = NULL;
        throw std::runtime_error("could not claim USB interface");
    }
}

/**
 * Send command `cmd` with an empty payload.
 */
void JasonRecorder::send_message(std::uint16_t cmd) {
    std::uint8_t *no_payload = NULL;
    send_message(cmd, no_payload, 0);
}

void JasonRecorder::send_message(std::uint16_t cmd, std::vector<std::uint8_t> &payload) {
    send_message(cmd, payload.data(), payload.size());
}

/**
 * Frame and send one command message to the device over ENDPOINT_SEND.
 *
 * Frame layout: FRAME_START, command (2 bytes, big-endian), payload length
 * (2 bytes, big-endian), payload, checksum (1 byte). The checksum is chosen so that
 * FRAME_START XORed with every byte after the start marker, including the
 * checksum itself, is zero.
 *
 * @param cmd      command identifier.
 * @param payload  payload bytes; may be NULL only when `length` is 0.
 * @param length   payload size in bytes; must fit in the 16-bit length field.
 * @throws std::logic_error if set_device() has not been called.
 * @throws std::length_error if `length` exceeds 0xFFFF.
 * @throws std::invalid_argument if `payload` is NULL but `length` is nonzero.
 * @throws std::runtime_error if the USB transfer fails or is incomplete.
 */
void JasonRecorder::send_message(std::uint16_t cmd, std::uint8_t *payload, size_t length) {
    if (!handle) {
        throw std::logic_error("must call set_device() first");
    }
    // the frame carries the payload size in two bytes; larger payloads would be silently truncated
    if (length > 0xFFFF) {
        throw std::length_error("payload too large for a single message");
    }
    if (length && !payload) {
        throw std::invalid_argument("payload must not be NULL when length is nonzero");
    }
    std::vector<std::uint8_t> data;
    data.reserve(6 + length);
    data.push_back(FRAME_START);
    data.push_back(static_cast<std::uint8_t>((cmd >> 8) & 0xFF));
    data.push_back(static_cast<std::uint8_t>(cmd & 0xFF));
    data.push_back(static_cast<std::uint8_t>((length >> 8) & 0xFF));
    data.push_back(static_cast<std::uint8_t>(length & 0xFF));
    if (length) {
        data.insert(data.end(), payload, payload + length);
    }
    // compute the checksum: FRAME_START XORed with every byte after the start marker
    std::uint8_t checksum = FRAME_START;
    for (size_t i = 1; i < data.size(); ++i) {
        checksum ^= data[i];
    }
    data.push_back(checksum);
    // send message, allow a maximum of 10 seconds for it to go through
    int sent = 0;
    if (libusb_bulk_transfer(handle, ENDPOINT_SEND, data.data(), static_cast<int>(data.size()), &sent, 10000) < 0) {
        throw std::runtime_error("could not send message to device");
    }
    if (static_cast<size_t>(sent) != data.size()) {
        throw std::runtime_error("could not send complete message to device");
    }
}

/**
 * Ask the device to start recording with the given format, then remember that format.
 *
 * The START_ID message payload is: START, sample rate (4 bytes, big-endian),
 * number of channels, 8 * depth, filter number.
 *
 * @param num_channels  number of channels to record.
 * @param sample_rate   sample rate; only its low 32 bits are transmitted.
 * @param depth         sample width; sent as 8 * depth, and used as bytes per sample by get_samples().
 * @param num_filter    filter selector forwarded to the device.
 */
void JasonRecorder::start_recording(std::uint8_t num_channels, size_t sample_rate, std::uint8_t depth, std::uint8_t num_filter) {
    std::vector<std::uint8_t> payload;
    payload.reserve(8);
    payload.push_back(START);
    // big-endian 32-bit sample rate
    for (int shift = 24; shift >= 0; shift -= 8) {
        payload.push_back(static_cast<std::uint8_t>((sample_rate >> shift) & 0xFF));
    }
    payload.push_back(num_channels);
    payload.push_back(static_cast<std::uint8_t>(8 * depth));
    payload.push_back(num_filter);
    send_message(START_ID, payload);
    // only reached if send_message() did not throw; store the format for stop/get_samples
    this->num_channels = num_channels;
    this->sample_rate = sample_rate;
    this->depth = depth;
    this->num_filter = num_filter;
    recording = true;
}


/**
 * Ask the device to stop recording.
 *
 * The payload mirrors start_recording(): STOP, the stored sample rate (4 bytes,
 * big-endian), channel count, 8 * depth, and filter number, sent under START_ID.
 * NOTE(review): STOP is sent with the START_ID command, presumably because START/STOP
 * act as a sub-command byte. Confirm against the device protocol.
 */
void JasonRecorder::stop_recording() {
    const auto rate = this->sample_rate;
    std::vector<std::uint8_t> payload;
    payload.reserve(8);
    payload.push_back(STOP);
    payload.push_back(static_cast<std::uint8_t>((rate >> 24) & 0xFF));
    payload.push_back(static_cast<std::uint8_t>((rate >> 16) & 0xFF));
    payload.push_back(static_cast<std::uint8_t>((rate >> 8) & 0xFF));
    payload.push_back(static_cast<std::uint8_t>(rate & 0xFF));
    payload.push_back(this->num_channels);
    payload.push_back(static_cast<std::uint8_t>(8 * this->depth));
    payload.push_back(this->num_filter);
    send_message(START_ID, payload);
    recording = false;
}

/**
 * Read one message from ENDPOINT_RECEIVE into `buffer`.
 *
 * @param buffer    destination; must hold at least MAX_MSG_LENGTH bytes.
 * @param max_wait  libusb transfer timeout in milliseconds (libusb treats 0 as "no timeout").
 * @return the number of bytes received. A timeout is not an error: whatever libusb
 *         reports as transferred (possibly 0) is returned.
 * @throws std::logic_error if set_device() has not been called.
 * @throws std::runtime_error on overflow or any other transfer error.
 */
size_t JasonRecorder::receive_message(uint8_t *buffer, size_t max_wait) {
    if (handle == NULL) {
        throw std::logic_error("must call set_device() first");
    }
    int received = 0;
    const int status = libusb_bulk_transfer(handle, ENDPOINT_RECEIVE, buffer, MAX_MSG_LENGTH, &received, max_wait);
    if (status == LIBUSB_ERROR_OVERFLOW) {
        throw std::runtime_error("buffer too small to receive message from device");
    }
    const bool failed = (status < 0) && (status != LIBUSB_ERROR_TIMEOUT);
    if (failed) {
        throw std::runtime_error("could not receive message from device");
    }
    return received;
}

/**
 * Read messages from the device until a sample-data frame arrives (or, when verbose,
 * a status frame), and copy its contents out.
 *
 * Messages not starting with FRAME_START are skipped. So are data frames too short to
 * hold the 6-byte header plus additional_data_size bytes, and (unless verbose) status
 * messages. In each case the next message is read.
 *
 * @param samples   receives the frame's sample bytes, truncated to a whole number of
 *                  (num_channels * depth)-byte groups. They stay channel-blocked as received
 *                  when `planar` is true or there is a single channel, and are interleaved
 *                  otherwise. Emptied when a read returns no data and max_wait > 0, or when
 *                  a status message is printed.
 * @param imu_data  receives the additional_data_size bytes that follow the 6-byte header
 *                  of a data frame (presumably IMU readings, going by the name).
 * @param planar    keep the per-channel block layout instead of interleaving.
 * @param max_wait  per-read timeout in milliseconds, passed to receive_message().
 * @throws std::logic_error if no recording format (channels, rate, depth) has been set.
 */
void JasonRecorder::get_samples(std::vector<std::uint8_t> &samples, std::vector<std::uint8_t> &imu_data, bool planar, size_t max_wait) {
    // depth is checked too: it is used as a divisor below
    if (!num_channels || !sample_rate || !this->depth) {
        throw std::logic_error("must call start_recording() first");
    }
    const size_t frame_bytes = static_cast<size_t>(num_channels) * this->depth;
    std::array<std::uint8_t, MAX_MSG_LENGTH> buffer{};
    // widen a single buffer byte for arithmetic / printing
    auto val = [&buffer](size_t i) { return static_cast<size_t>(buffer[i]); };
    while (true) {
        const size_t received = receive_message(buffer.data(), max_wait);
        if (!received) {
            if (max_wait > 0) {
                // we timed out, we do not want to wait again
                samples.resize(0);
                break;
            }
            continue;
        }
        if (buffer[0] != FRAME_START) {
            continue;  // invalid message
        }
        // we could read the payload length (bytes 3-4), but it is wrong for sample data
        const int msg_id = (buffer[1] << 8) | buffer[2];
        if (msg_id == DATA_ID) {
            // samples start after the 6-byte header and the additional data
            const size_t start = this->additional_data_size + 6;
            if (received < start) {
                continue;  // truncated frame: would underflow below and read stale bytes
            }
            imu_data.assign(buffer.data() + 6, buffer.data() + start);
            // keep only complete groups of one sample per channel
            size_t num_samples = received - start;
            num_samples -= num_samples % frame_bytes;
            std::uint8_t *payload = buffer.data() + start;
            if (planar || (num_channels == 1)) {
                // copy out directly
                samples.assign(payload, payload + num_samples);
            }
            else {
                // convert from blocked channels to interleaved channels
                samples.resize(num_samples);
                JasonRecorder::interleave_channels(payload, samples.data(), num_samples,
                                                   this->num_channels, this->depth);
            }
            break;
        }
        if (this->verbose && (msg_id == STATUS_ID)) {
            samples.resize(0);
            // XOR over bytes 0..30 is zero when the trailing checksum byte is correct
            // (assumes a 31-byte status frame — TODO confirm against the device protocol)
            std::uint8_t cks = FRAME_START;  // buffer[0] == FRAME_START already checked
            for (size_t i = 1; i < 31; ++i) {
                cks ^= buffer[i];
            }
            std::cout << " Sr: " << ((val(5) << 24) | (val(6) << 16) | (val(7) << 8) | val(8))
                      << " #Ch: " << val(9) << " D: " << val(10) << " Time: "
                      << 2000 + val(11) << '-' << val(12) << '-' << val(13) << ' '
                      << val(14) << ':' << val(15) << ':' << val(16)
                      << " UUID: " << std::hex << std::setfill('0');
            // two hex digits per byte, so byte boundaries in the UUID stay unambiguous
            for (size_t i = 17; i <= 24; ++i) {
                std::cout << std::setw(2) << val(i);
            }
            std::cout << std::dec << std::setfill(' ')
                      << " Rec: " << (buffer[25] != 0)
                      << " SPI: " << val(26) << val(27) << val(28) << val(29)
                      << " CKS: " << (cks == 0 ? "True" : "False") << std::endl;
            break;
        }
        // any other message is ignored; read the next one
    }
}

/**
 * Convert channel-blocked sample bytes into channel-interleaved sample bytes.
 *
 * The input holds num_channels consecutive blocks, one per channel, each
 * num_bytes / num_channels bytes long and made of `depth`-byte samples. The output
 * takes `depth` bytes from each channel in turn (round-robin). The output is written
 * strictly sequentially, so it can target a memory-mapped file. Callers are expected
 * to pass a num_bytes that is a multiple of num_channels * depth.
 *
 * @param input         channel-blocked source, num_bytes long.
 * @param output        interleaved destination, num_bytes long.
 * @param num_bytes     total number of bytes to convert.
 * @param num_channels  number of channel blocks, in [1, MAX_CHANNELS].
 * @param depth         bytes per sample, at least 1.
 * @throws std::out_of_range if num_channels or depth is out of range.
 */
void JasonRecorder::interleave_channels(std::uint8_t *input, std::uint8_t *output, size_t num_bytes,
                                        size_t num_channels, size_t depth) {
    if ((num_channels < 1) || (num_channels > MAX_CHANNELS)) {
        throw std::out_of_range("num_channels must be in [1, " + std::to_string(MAX_CHANNELS) + "]");
    }
    if (depth < 1) {
        throw std::out_of_range("depth must be at least 1");
    }
    const size_t bytes_per_channel = num_bytes / num_channels;
    // `offset` is the position within each channel block. The output index advances
    // sequentially through channels, then bytes of each sample. No per-channel pointer
    // array (formerly a non-standard VLA) is needed.
    size_t out = 0;
    for (size_t offset = 0; out < num_bytes; offset += depth) {
        for (size_t channel = 0; (channel < num_channels) && (out < num_bytes); ++channel) {
            const std::uint8_t *sample = input + channel * bytes_per_channel + offset;
            for (size_t k = 0; (k < depth) && (out < num_bytes); ++k) {
                output[out++] = sample[k];
            }
        }
    }
}