/*
 *  Copyright (C) 2004-2025 Savoir-faire Linux Inc.
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
#include "channel_socket.h"

#include <msgpack.hpp>

namespace dhtnet {

/**
 * Build a socket receive callback that decodes a stream of msgpack objects of type T.
 *
 * Incoming bytes are appended to a streaming msgpack::unpacker owned (through a
 * shared_ptr) by the returned callback. Copies of the callback therefore share the
 * same decoder state. Every complete object is converted to T and handed to @p userCb.
 *
 * The generated callback returns:
 *  - len when the chunk was buffered and all complete objects were dispatched;
 *  - -ec.value() as soon as @p userCb returns a non-zero error code (objects still
 *    pending in the decoder are not dispatched during this call);
 *  - -1 when the data cannot be decoded by msgpack or converted to T.
 */
template<typename T>
ChannelSocket::RecvCb
buildMsgpackReader(std::function<std::error_code(T&&)> userCb)
{
    return [cb = std::move(userCb), unpacker = std::make_shared<msgpack::unpacker>()](const uint8_t* buf,
                                                                                      std::size_t len) -> ssize_t {
        // Append the new bytes to the decoder's internal buffer. The copy is skipped
        // for an empty chunk: memcpy with a null source is undefined even for 0 bytes.
        if (len > 0) {
            unpacker->reserve_buffer(len);
            std::memcpy(unpacker->buffer(), buf, len);
            unpacker->buffer_consumed(len);
        }

        // Catch msgpack errors so malformed peer data cannot throw out of the
        // receive callback and terminate the reader thread.
        try {
            msgpack::unpacked result;
            while (unpacker->next(result)) {
                if (auto ec = cb(result.get().as<T>()); ec) {
                    return -static_cast<ssize_t>(ec.value());
                }
            }
        } catch (const msgpack::unpack_error&) {
            // Common base of parse_error and of the size_overflow family
            // (array/map/str/bin/ext/depth limits); catching only parse_error
            // let the limit errors escape.
            return -1;
        } catch (const std::bad_cast&) {
            // msgpack::type_error (object not convertible to T) derives from std::bad_cast.
            return -1;
        }
        return static_cast<ssize_t>(len);
    };
}

template<typename T>
class MessageChannel
{
public:
    using RecvCb = std::function<std::error_code(T&&)>;

    MessageChannel(std::shared_ptr<ChannelSocketInterface> channelSocket, RecvCb&& cb, OnShutdownCb&& onShutdown = {})
        : channelSocket_(std::move(channelSocket))
    {
        channelSocket_->setOnRecv(buildMsgpackReader<T>(std::move(cb)));
        if (onShutdown) {
            channelSocket_->onShutdown(std::move(onShutdown));
        }
    }

    std::error_code send(const T& msg)
    {
        msgpack::sbuffer sbuf;
        msgpack::pack(sbuf, msg);
        std::error_code ec;
        auto written = channelSocket_->write(reinterpret_cast<const uint8_t*>(sbuf.data()), sbuf.size(), ec);
        if (written != sbuf.size() && !ec)
            ec = std::make_error_code(std::errc::io_error);
        return ec;
    }

private:
    std::shared_ptr<ChannelSocketInterface> channelSocket_;
};

} // namespace dhtnet
