#nullable enable
using System;
using System.Collections.Generic;
using IGP.UnitySDK.Models;
using UnityEngine;

namespace IGP.UnitySDK.Network
{
    /// <summary>
    /// Priority class of an outgoing reliable payload. The scheduler compares these values
    /// numerically: anything at or above <see cref="ReliableData"/> is subject to the byte
    /// budget and the waitSnd throttle, anything below it is not.
    /// </summary>
    internal enum IGPOutgoingPriority
    {
        /// <summary>Lowest numeric value; exempt from the byte budget and the waitSnd throttle.</summary>
        Control = 0,

        /// <summary>Default priority for enqueued payloads; rate limited and throttled.</summary>
        ReliableData = 1,
    }

    internal sealed class IGPOutgoingReliableScheduler
    {
        // Initial values for the public tuning properties of the same names (minus "Default").
        private const int DefaultOutgoingBytesPerSecond = 48 * 1024;
        private const int DefaultMaxPayloadsPerDrain = 2;
        private const int DefaultWaitSndHighWatermark = 256;
        private const int DefaultWaitSndLowWatermark = 64;
        private const int DefaultMaxQueuedPayloads = 1024;
        private const long DefaultMaxQueuedBytesTotal = 8L * 1024L * 1024L;
        private const long DefaultMaxQueuedBytesPerPeer = 4L * 1024L * 1024L;

        // Lock held by the methods and diagnostic getters that touch the queue state and counters below.
        private readonly object sync = new object();
        // Pending payloads per remote player id, in send order (ordinal key comparison).
        private readonly Dictionary<string, LinkedList<QueueItem>> queues =
            new Dictionary<string, LinkedList<QueueItem>>(StringComparer.Ordinal);
        // Sum of EstimatedBytes currently queued for each remote player id.
        private readonly Dictionary<string, long> queuedBytesByPeer =
            new Dictionary<string, long>(StringComparer.Ordinal);
        // Round-robin rotation of peers; activePeerSet is consulted by AddActivePeer so an id
        // is not enqueued into activePeers while it is already a member of the set.
        private readonly Queue<string> activePeers = new Queue<string>();
        private readonly HashSet<string> activePeerSet = new HashSet<string>(StringComparer.Ordinal);
        // Clock read by RefillTokens; the parameterless constructor supplies Time.realtimeSinceStartup.
        private readonly Func<float> timeProvider;

        // Bytes currently available to rate-limited payloads.
        private float tokenBucketBytes;
        // Clock value at the previous refill; a negative value means no refill has happened yet.
        private float lastTokenTime = -1f;
        // Totals across all peer queues.
        private int queuedPayloads;
        private long queuedBytes;
        // Throttle state maintained by UpdateWaitSndThrottle (throttleCode is also set to
        // "tokenBudget" by TryDequeueReady).
        private bool throttledByWaitSnd;
        private string throttleCode = "none";
        // Most recent waitSnd value passed to Drain; -1 until the first drain and after Clear.
        private int lastWaitSnd = -1;
        // Incremented by RemovePeer and Clear; RequeueFront ignores items dequeued under an older value.
        private long generation;

        // Lifetime drop counters (never reset by ConsumeDiagnosticsSnapshot).
        private long totalDroppedPayloads;
        private long totalDroppedBytes;
        // Per-interval counters; ConsumeDiagnosticsSnapshot reports them and resets them to zero.
        private long intervalEnqueuedPayloads;
        private long intervalEnqueuedBytes;
        private long intervalSentPayloads;
        private long intervalSentBytes;
        private long intervalDroppedPayloads;
        private long intervalDroppedBytes;
        private long intervalSendFailures;

        /// <summary>
        /// Creates a scheduler whose token bucket is driven by <c>Time.realtimeSinceStartup</c>.
        /// </summary>
        public IGPOutgoingReliableScheduler()
            : this(timeProvider: () => Time.realtimeSinceStartup)
        {
        }

        /// <summary>
        /// Creates a scheduler that reads the current time from <paramref name="timeProvider"/>.
        /// The token bucket starts at the default bytes-per-second value.
        /// </summary>
        /// <param name="timeProvider">Clock callback invoked on every token refill.</param>
        /// <exception cref="ArgumentNullException"><paramref name="timeProvider"/> is null.</exception>
        internal IGPOutgoingReliableScheduler(Func<float> timeProvider)
        {
            if (timeProvider == null)
            {
                throw new ArgumentNullException(nameof(timeProvider));
            }

            tokenBucketBytes = DefaultOutgoingBytesPerSecond;
            this.timeProvider = timeProvider;
        }

        /// <summary>
        /// Token refill rate in bytes per unit of the injected clock, and also the bucket capacity.
        /// Values below zero are treated as zero wherever the value is read.
        /// </summary>
        public int OutgoingBytesPerSecond { get; set; } = DefaultOutgoingBytesPerSecond;

        /// <summary>
        /// Maximum number of successful sends per <see cref="Drain"/> call; values below 1 are treated as 1.
        /// </summary>
        public int MaxPayloadsPerDrain { get; set; } = DefaultMaxPayloadsPerDrain;

        /// <summary>
        /// The waitSnd throttle engages when the value passed to <see cref="Drain"/> is strictly greater than this.
        /// </summary>
        public int WaitSndHighWatermark { get; set; } = DefaultWaitSndHighWatermark;

        /// <summary>
        /// Once engaged, the waitSnd throttle releases when the value passed to <see cref="Drain"/>
        /// is less than or equal to this.
        /// </summary>
        public int WaitSndLowWatermark { get; set; } = DefaultWaitSndLowWatermark;

        /// <summary>
        /// A batch is rejected (and recorded as dropped) if accepting it would push the queued payload count above this.
        /// </summary>
        public int MaxQueuedPayloads { get; set; } = DefaultMaxQueuedPayloads;

        /// <summary>
        /// A batch is rejected (and recorded as dropped) if accepting it would push the total queued bytes above this.
        /// </summary>
        public long MaxQueuedBytesTotal { get; set; } = DefaultMaxQueuedBytesTotal;

        /// <summary>
        /// A batch is rejected (and recorded as dropped) if accepting it would push the target peer's queued bytes above this.
        /// </summary>
        public long MaxQueuedBytesPerPeer { get; set; } = DefaultMaxQueuedBytesPerPeer;

        /// <summary>Number of payloads currently queued across all peers.</summary>
        public int QueuedPayloadCount
        {
            get
            {
                int count;
                lock (sync)
                {
                    count = queuedPayloads;
                }

                return count;
            }
        }

        /// <summary>Sum of the estimated sizes of all payloads currently queued.</summary>
        public long QueuedBytes
        {
            get
            {
                long bytes;
                lock (sync)
                {
                    bytes = queuedBytes;
                }

                return bytes;
            }
        }

        /// <summary>Lifetime count of payloads rejected by the queue limits.</summary>
        public long TotalDroppedPayloads
        {
            get
            {
                long dropped;
                lock (sync)
                {
                    dropped = totalDroppedPayloads;
                }

                return dropped;
            }
        }

        /// <summary>Lifetime sum of the estimated sizes of payloads rejected by the queue limits.</summary>
        public long TotalDroppedBytes
        {
            get
            {
                long dropped;
                lock (sync)
                {
                    dropped = totalDroppedBytes;
                }

                return dropped;
            }
        }

        /// <summary>True while the waitSnd throttle is engaged.</summary>
        public bool IsThrottled
        {
            get
            {
                bool throttled;
                lock (sync)
                {
                    throttled = throttledByWaitSnd;
                }

                return throttled;
            }
        }

        /// <summary>
        /// Short code describing why the most recent dequeue attempt was held back, or "none".
        /// </summary>
        public string ThrottleCode
        {
            get
            {
                string code;
                lock (sync)
                {
                    code = throttleCode;
                }

                return code;
            }
        }

        /// <summary>
        /// Queues a single message for <paramref name="remotePlayerId"/> by wrapping it in a
        /// one-element batch and delegating to <see cref="EnqueueBatch"/>.
        /// </summary>
        /// <param name="remotePlayerId">Destination peer id.</param>
        /// <param name="message">Message to send.</param>
        /// <param name="estimatedBytes">Size estimate used for queue limits and the byte budget.</param>
        /// <param name="priority">Scheduling priority; defaults to reliable data.</param>
        /// <returns>The result of <see cref="EnqueueBatch"/> for the one-element batch.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
        public bool Enqueue(
            string remotePlayerId,
            Message message,
            int estimatedBytes,
            IGPOutgoingPriority priority = IGPOutgoingPriority.ReliableData)
        {
            IReadOnlyList<IGPOutgoingReliablePayload> singleItemBatch = new[]
            {
                new IGPOutgoingReliablePayload(message, estimatedBytes, priority)
            };

            return EnqueueBatch(remotePlayerId, singleItemBatch);
        }

        /// <summary>
        /// Appends a batch of payloads to the queue of <paramref name="remotePlayerId"/>, all or nothing.
        /// </summary>
        /// <param name="remotePlayerId">Destination peer id; blank ids are rejected.</param>
        /// <param name="payloads">Payloads to append in order; must be non-empty and contain no null message.</param>
        /// <returns>
        /// True when every payload was queued. False when the arguments are invalid (nothing is
        /// recorded) or when a queue limit would be exceeded (the whole batch is recorded as dropped).
        /// </returns>
        public bool EnqueueBatch(string remotePlayerId, IReadOnlyList<IGPOutgoingReliablePayload> payloads)
        {
            if (string.IsNullOrWhiteSpace(remotePlayerId))
            {
                return false;
            }

            if (payloads == null || payloads.Count == 0)
            {
                return false;
            }

            // Validate and size the batch before taking the lock.
            long batchBytes = 0;
            for (int index = 0; index < payloads.Count; index++)
            {
                var candidate = payloads[index];
                if (candidate.Message == null)
                {
                    return false;
                }

                batchBytes += Math.Max(0, candidate.EstimatedBytes);
            }

            lock (sync)
            {
                if (!queuedBytesByPeer.TryGetValue(remotePlayerId, out long peerBytes))
                {
                    peerBytes = 0;
                }

                bool exceedsPayloadLimit = queuedPayloads + payloads.Count > MaxQueuedPayloads;
                bool exceedsTotalBytes = queuedBytes + batchBytes > MaxQueuedBytesTotal;
                bool exceedsPeerBytes = peerBytes + batchBytes > MaxQueuedBytesPerPeer;
                if (exceedsPayloadLimit || exceedsTotalBytes || exceedsPeerBytes)
                {
                    RecordDrop(payloads.Count, batchBytes);
                    return false;
                }

                // First payload for this peer: create its queue and put it into the rotation.
                if (!queues.TryGetValue(remotePlayerId, out var peerQueue))
                {
                    peerQueue = new LinkedList<QueueItem>();
                    queues[remotePlayerId] = peerQueue;
                    queuedBytesByPeer[remotePlayerId] = 0;
                    AddActivePeer(remotePlayerId);
                }

                for (int index = 0; index < payloads.Count; index++)
                {
                    var accepted = payloads[index];
                    int acceptedBytes = Math.Max(0, accepted.EstimatedBytes);
                    var entry = new QueueItem(remotePlayerId, accepted.Message!, acceptedBytes, accepted.Priority);
                    peerQueue.AddLast(entry);

                    queuedPayloads++;
                    intervalEnqueuedPayloads++;

                    queuedBytes += acceptedBytes;
                    queuedBytesByPeer[remotePlayerId] += acceptedBytes;
                    intervalEnqueuedBytes += acceptedBytes;
                }

                return true;
            }
        }

        /// <summary>
        /// Sends queued payloads through <paramref name="send"/> until the per-drain limit is
        /// reached, nothing is ready, or a send fails.
        /// </summary>
        /// <param name="currentWaitSnd">Latest send-backlog reading used for throttling decisions.</param>
        /// <param name="send">
        /// Callback that transmits one message and returns true on success. It is invoked
        /// outside the scheduler lock; an exception is treated the same as returning false.
        /// </param>
        /// <returns>Counts of payloads and bytes sent, and the number of failed sends (0 or 1).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="send"/> is null.</exception>
        public IGPOutgoingReliableDrainResult Drain(int currentWaitSnd, Func<Message, bool> send)
        {
            if (send == null)
            {
                throw new ArgumentNullException(nameof(send));
            }

            int deliveredCount = 0;
            long deliveredBytes = 0;
            int failedSends = 0;

            while (deliveredCount < Math.Max(1, MaxPayloadsPerDrain)
                && TryDequeueReady(currentWaitSnd, out var next, out var dequeuedGeneration))
            {
                bool delivered;
                try
                {
                    delivered = send(next.Message);
                }
                catch
                {
                    delivered = false;
                }

                if (delivered)
                {
                    lock (sync)
                    {
                        intervalSentPayloads++;
                        intervalSentBytes += next.EstimatedBytes;
                    }

                    deliveredCount++;
                    deliveredBytes += next.EstimatedBytes;
                    continue;
                }

                // A failed send goes back to the head of its peer's queue and ends this drain.
                failedSends++;
                RequeueFront(next, dequeuedGeneration);
                break;
            }

            if (failedSends > 0)
            {
                lock (sync)
                {
                    intervalSendFailures += failedSends;
                }
            }

            return new IGPOutgoingReliableDrainResult(deliveredCount, deliveredBytes, failedSends);
        }

        /// <summary>
        /// Discards everything queued for <paramref name="remotePlayerId"/> and takes the peer
        /// out of the round-robin rotation. Blank ids are ignored.
        /// </summary>
        /// <remarks>
        /// Discarded payloads are not counted as drops. The generation counter is always
        /// advanced, so a payload that was dequeued before this call and then fails to send is
        /// not put back by RequeueFront. Because the counter is shared, this also applies to an
        /// in-flight payload that belongs to a different peer.
        /// </remarks>
        /// <param name="remotePlayerId">Peer whose queue should be discarded.</param>
        public void RemovePeer(string remotePlayerId)
        {
            if (string.IsNullOrWhiteSpace(remotePlayerId))
            {
                return;
            }

            lock (sync)
            {
                // Invalidate in-flight items even when no queue exists right now: the peer's only
                // payload may be mid-send (its queue was deleted when it was dequeued), and
                // without this a failed send would re-create a queue for the removed peer.
                generation++;

                // Remove the id from the rotation queue as well as the membership set. Clearing
                // only the set leaves a stale id in activePeers; if the peer enqueues again before
                // that id is drained, AddActivePeer inserts a second copy, giving the peer two
                // turns per rotation and growing the queue on every remove/enqueue cycle.
                if (activePeerSet.Remove(remotePlayerId))
                {
                    for (int remaining = activePeers.Count; remaining > 0; remaining--)
                    {
                        string candidate = activePeers.Dequeue();
                        if (!string.Equals(candidate, remotePlayerId, StringComparison.Ordinal))
                        {
                            // Re-enqueueing in dequeue order preserves the rotation of the other peers.
                            activePeers.Enqueue(candidate);
                        }
                    }
                }

                if (!queues.TryGetValue(remotePlayerId, out var queue))
                {
                    return;
                }

                foreach (var item in queue)
                {
                    queuedPayloads--;
                    queuedBytes -= item.EstimatedBytes;
                }

                queues.Remove(remotePlayerId);
                queuedBytesByPeer.Remove(remotePlayerId);
            }
        }

        /// <summary>
        /// Empties every queue and resets throttle and token-bucket state. Diagnostic counters
        /// are left untouched. Advances the generation so in-flight payloads are not requeued.
        /// </summary>
        public void Clear()
        {
            lock (sync)
            {
                // Invalidate anything dequeued before this point.
                generation++;

                // Queued data and rotation.
                activePeerSet.Clear();
                activePeers.Clear();
                queuedBytesByPeer.Clear();
                queues.Clear();
                queuedBytes = 0;
                queuedPayloads = 0;

                // Throttle state.
                lastWaitSnd = -1;
                throttleCode = "none";
                throttledByWaitSnd = false;

                // Token bucket: full again, and the next refill is treated as the first.
                lastTokenTime = -1f;
                tokenBucketBytes = Math.Max(0, OutgoingBytesPerSecond);
            }
        }

        /// <summary>
        /// Captures the current queue, drop, and throttle figures together with the per-interval
        /// counters, then resets the per-interval counters to zero.
        /// </summary>
        /// <returns>A snapshot of the values as they were before the reset.</returns>
        public IGPOutgoingReliableSchedulerSnapshot ConsumeDiagnosticsSnapshot()
        {
            lock (sync)
            {
                var captured = new IGPOutgoingReliableSchedulerSnapshot(
                    queuedPayloads: queuedPayloads,
                    queuedBytes: queuedBytes,
                    totalDroppedPayloads: totalDroppedPayloads,
                    totalDroppedBytes: totalDroppedBytes,
                    intervalEnqueuedPayloads: intervalEnqueuedPayloads,
                    intervalEnqueuedBytes: intervalEnqueuedBytes,
                    intervalSentPayloads: intervalSentPayloads,
                    intervalSentBytes: intervalSentBytes,
                    intervalDroppedPayloads: intervalDroppedPayloads,
                    intervalDroppedBytes: intervalDroppedBytes,
                    intervalSendFailures: intervalSendFailures,
                    isThrottled: throttledByWaitSnd,
                    throttleCode: throttleCode,
                    lastWaitSnd: lastWaitSnd);

                // Lifetime totals are kept; only the interval counters start over.
                intervalSendFailures = 0;
                intervalDroppedBytes = 0;
                intervalDroppedPayloads = 0;
                intervalSentBytes = 0;
                intervalSentPayloads = 0;
                intervalEnqueuedBytes = 0;
                intervalEnqueuedPayloads = 0;

                return captured;
            }
        }

        /// <summary>
        /// Refreshes the token bucket and throttle state, then removes and returns the next
        /// payload that may be sent, visiting peers in round-robin order.
        /// </summary>
        /// <remarks>
        /// Only the front payload of each peer is considered, so per-peer order is preserved.
        /// The search stops at the first peer whose front payload is held back by the waitSnd
        /// throttle or the byte budget; that peer moves to the back of the rotation.
        /// A rate-limited payload larger than the bucket capacity is released once the bucket
        /// is full, and the bucket balance then goes negative until refills repay it.
        /// </remarks>
        /// <param name="currentWaitSnd">Latest send-backlog reading used to update the throttle.</param>
        /// <param name="item">The dequeued payload when the method returns true; otherwise default.</param>
        /// <param name="itemGeneration">Generation counter value at the time of the call.</param>
        /// <returns>True when a payload was dequeued; false when nothing is ready.</returns>
        private bool TryDequeueReady(int currentWaitSnd, out QueueItem item, out long itemGeneration)
        {
            lock (sync)
            {
                RefillTokens();
                UpdateWaitSndThrottle(currentWaitSnd);

                // The generation cannot change while the lock is held, so capture it once.
                itemGeneration = generation;
                item = default;

                float bucketCapacity = Math.Max(0, OutgoingBytesPerSecond);

                for (int peersToCheck = activePeers.Count; peersToCheck > 0; peersToCheck--)
                {
                    string remotePlayerId = activePeers.Dequeue();
                    activePeerSet.Remove(remotePlayerId);

                    if (!queues.TryGetValue(remotePlayerId, out var queue) || queue.First == null)
                    {
                        // Stale rotation entry: drop any leftover bookkeeping and move on.
                        queues.Remove(remotePlayerId);
                        queuedBytesByPeer.Remove(remotePlayerId);
                        continue;
                    }

                    var front = queue.First.Value;
                    bool rateLimited = front.Priority >= IGPOutgoingPriority.ReliableData;

                    if (ShouldThrottle(front))
                    {
                        AddActivePeer(remotePlayerId);
                        return false;
                    }

                    if (rateLimited)
                    {
                        // The bucket never holds more than bucketCapacity, so a payload larger than
                        // that could never satisfy "tokens >= size" and would block its peer forever.
                        // Such a payload only needs a full bucket; the shortfall is carried as a
                        // negative balance, which keeps the long-term rate unchanged. A capacity
                        // of zero keeps its original meaning of "never release rate-limited bytes".
                        bool oversized = bucketCapacity > 0f && front.EstimatedBytes > bucketCapacity;
                        float requiredTokens = oversized ? bucketCapacity : front.EstimatedBytes;
                        if (tokenBucketBytes < requiredTokens)
                        {
                            throttleCode = "tokenBudget";
                            AddActivePeer(remotePlayerId);
                            return false;
                        }
                    }

                    queue.RemoveFirst();
                    queuedPayloads--;
                    queuedBytes -= front.EstimatedBytes;
                    queuedBytesByPeer[remotePlayerId] -= front.EstimatedBytes;
                    if (rateLimited)
                    {
                        tokenBucketBytes -= front.EstimatedBytes;
                    }

                    if (queue.Count > 0)
                    {
                        // More work for this peer: send it to the back of the rotation.
                        AddActivePeer(remotePlayerId);
                    }
                    else
                    {
                        queues.Remove(remotePlayerId);
                        queuedBytesByPeer.Remove(remotePlayerId);
                    }

                    item = front;
                    return true;
                }

                return false;
            }
        }

        /// <summary>
        /// Puts a payload whose send failed back at the head of its peer's queue and refunds the
        /// tokens it consumed. Does nothing if the generation has changed since it was dequeued.
        /// </summary>
        /// <param name="item">Payload returned by the dequeue step.</param>
        /// <param name="itemGeneration">Generation value reported when the payload was dequeued.</param>
        private void RequeueFront(QueueItem item, long itemGeneration)
        {
            lock (sync)
            {
                bool invalidated = itemGeneration != generation;
                if (invalidated)
                {
                    return;
                }

                string peerId = item.RemotePlayerId;
                int bytes = item.EstimatedBytes;

                // The peer's queue is deleted when it runs empty, so it may need re-creating.
                if (!queues.TryGetValue(peerId, out var peerQueue))
                {
                    peerQueue = new LinkedList<QueueItem>();
                    queues[peerId] = peerQueue;
                    queuedBytesByPeer[peerId] = 0;
                }

                peerQueue.AddFirst(item);
                queuedPayloads++;
                queuedBytes += bytes;
                queuedBytesByPeer[peerId] += bytes;

                if (item.Priority >= IGPOutgoingPriority.ReliableData)
                {
                    // Refund the budget, never exceeding the bucket capacity.
                    float capacity = Math.Max(0, OutgoingBytesPerSecond);
                    tokenBucketBytes = Math.Min(capacity, tokenBucketBytes + bytes);
                }

                AddActivePeer(peerId);
            }
        }

        /// <summary>
        /// Adds tokens for the time elapsed since the previous refill, capped at the bucket
        /// capacity. The first refill (negative last-refill time) fills the bucket completely.
        /// A clock that moves backwards adds nothing.
        /// </summary>
        private void RefillTokens()
        {
            float now = timeProvider();
            float capacity = Math.Max(0, OutgoingBytesPerSecond);

            bool firstRefill = lastTokenTime < 0f;
            if (firstRefill)
            {
                tokenBucketBytes = capacity;
            }
            else
            {
                float elapsed = Math.Max(0f, now - lastTokenTime);
                tokenBucketBytes = Math.Min(capacity, tokenBucketBytes + elapsed * capacity);
            }

            lastTokenTime = now;
        }

        /// <summary>
        /// Records the latest waitSnd reading and updates the throttle with hysteresis: it
        /// engages above the high watermark and stays engaged until the reading falls to the
        /// low watermark or below. A negative reading disables the throttle.
        /// </summary>
        /// <param name="currentWaitSnd">Latest send-backlog reading.</param>
        private void UpdateWaitSndThrottle(int currentWaitSnd)
        {
            lastWaitSnd = currentWaitSnd;

            if (currentWaitSnd < 0)
            {
                throttledByWaitSnd = false;
                throttleCode = "none";
                return;
            }

            if (!throttledByWaitSnd)
            {
                if (currentWaitSnd > WaitSndHighWatermark)
                {
                    throttledByWaitSnd = true;
                    throttleCode = "waitSndHigh";
                }
                else
                {
                    throttleCode = "none";
                }

                return;
            }

            // Already throttled: release only once the backlog reaches the low watermark.
            if (currentWaitSnd <= WaitSndLowWatermark)
            {
                throttledByWaitSnd = false;
                throttleCode = "none";
            }
            else
            {
                throttleCode = "waitSndHysteresis";
            }
        }

        /// <summary>
        /// Tells whether <paramref name="item"/> must wait because the waitSnd throttle is
        /// engaged. Priorities below reliable data are never held back.
        /// </summary>
        private bool ShouldThrottle(QueueItem item)
        {
            if (!throttledByWaitSnd)
            {
                return false;
            }

            return item.Priority >= IGPOutgoingPriority.ReliableData;
        }

        /// <summary>
        /// Appends the peer to the round-robin rotation unless the membership set already contains it.
        /// </summary>
        private void AddActivePeer(string remotePlayerId)
        {
            bool newlyAdded = activePeerSet.Add(remotePlayerId);
            if (!newlyAdded)
            {
                return;
            }

            activePeers.Enqueue(remotePlayerId);
        }

        /// <summary>
        /// Adds a rejected batch to both the lifetime and the per-interval drop counters.
        /// </summary>
        /// <param name="payloads">Number of payloads rejected.</param>
        /// <param name="bytes">Combined estimated size of the rejected payloads.</param>
        private void RecordDrop(long payloads, long bytes)
        {
            // Payload counts.
            intervalDroppedPayloads += payloads;
            totalDroppedPayloads += payloads;

            // Byte counts.
            intervalDroppedBytes += bytes;
            totalDroppedBytes += bytes;
        }

        /// <summary>
        /// Immutable queue entry: one message bound to its destination peer, size estimate, and priority.
        /// </summary>
        private readonly struct QueueItem
        {
            /// <summary>Creates an entry; a negative <paramref name="estimatedBytes"/> is stored as zero.</summary>
            public QueueItem(
                string remotePlayerId,
                Message message,
                int estimatedBytes,
                IGPOutgoingPriority priority)
            {
                Priority = priority;
                EstimatedBytes = estimatedBytes < 0 ? 0 : estimatedBytes;
                Message = message;
                RemotePlayerId = remotePlayerId;
            }

            /// <summary>Destination peer id.</summary>
            public string RemotePlayerId { get; }

            /// <summary>Message to transmit.</summary>
            public Message Message { get; }

            /// <summary>Non-negative size estimate used for accounting.</summary>
            public int EstimatedBytes { get; }

            /// <summary>Scheduling priority.</summary>
            public IGPOutgoingPriority Priority { get; }
        }
    }

    /// <summary>
    /// Immutable description of one message to enqueue: the message, its size estimate, and its priority.
    /// A default-initialized instance has a null <see cref="Message"/>.
    /// </summary>
    internal readonly struct IGPOutgoingReliablePayload
    {
        /// <summary>Creates a payload; a negative <paramref name="estimatedBytes"/> is stored as zero.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="message"/> is null.</exception>
        public IGPOutgoingReliablePayload(Message message, int estimatedBytes, IGPOutgoingPriority priority)
        {
            if (message == null)
            {
                throw new ArgumentNullException(nameof(message));
            }

            Message = message;
            EstimatedBytes = estimatedBytes < 0 ? 0 : estimatedBytes;
            Priority = priority;
        }

        /// <summary>Message to transmit.</summary>
        public Message Message { get; }

        /// <summary>Non-negative size estimate used for accounting.</summary>
        public int EstimatedBytes { get; }

        /// <summary>Scheduling priority.</summary>
        public IGPOutgoingPriority Priority { get; }
    }

    /// <summary>
    /// Immutable outcome of a single drain pass: how much was sent and how many sends failed.
    /// </summary>
    internal readonly struct IGPOutgoingReliableDrainResult
    {
        /// <summary>Stores the three figures exactly as given.</summary>
        public IGPOutgoingReliableDrainResult(int sentPayloads, long sentBytes, int sendFailures)
        {
            SendFailures = sendFailures;
            SentBytes = sentBytes;
            SentPayloads = sentPayloads;
        }

        /// <summary>Number of payloads sent successfully.</summary>
        public int SentPayloads { get; }

        /// <summary>Combined estimated size of the payloads sent successfully.</summary>
        public long SentBytes { get; }

        /// <summary>Number of send attempts that failed.</summary>
        public int SendFailures { get; }
    }

    /// <summary>
    /// Immutable set of scheduler diagnostics: current queue size, lifetime drop totals,
    /// per-interval counters, and throttle state.
    /// </summary>
    internal readonly struct IGPOutgoingReliableSchedulerSnapshot
    {
        /// <summary>
        /// Stores every figure as given, except that a null, empty, or whitespace
        /// <paramref name="throttleCode"/> is replaced by "none".
        /// </summary>
        public IGPOutgoingReliableSchedulerSnapshot(
            int queuedPayloads,
            long queuedBytes,
            long totalDroppedPayloads,
            long totalDroppedBytes,
            long intervalEnqueuedPayloads,
            long intervalEnqueuedBytes,
            long intervalSentPayloads,
            long intervalSentBytes,
            long intervalDroppedPayloads,
            long intervalDroppedBytes,
            long intervalSendFailures,
            bool isThrottled,
            string throttleCode,
            int lastWaitSnd)
        {
            // Current queue contents.
            QueuedPayloads = queuedPayloads;
            QueuedBytes = queuedBytes;

            // Lifetime totals.
            TotalDroppedPayloads = totalDroppedPayloads;
            TotalDroppedBytes = totalDroppedBytes;

            // Per-interval counters.
            IntervalEnqueuedPayloads = intervalEnqueuedPayloads;
            IntervalEnqueuedBytes = intervalEnqueuedBytes;
            IntervalSentPayloads = intervalSentPayloads;
            IntervalSentBytes = intervalSentBytes;
            IntervalDroppedPayloads = intervalDroppedPayloads;
            IntervalDroppedBytes = intervalDroppedBytes;
            IntervalSendFailures = intervalSendFailures;

            // Throttle state.
            IsThrottled = isThrottled;
            LastWaitSnd = lastWaitSnd;
            if (string.IsNullOrWhiteSpace(throttleCode))
            {
                ThrottleCode = "none";
            }
            else
            {
                ThrottleCode = throttleCode;
            }
        }

        /// <summary>Payloads queued at capture time.</summary>
        public int QueuedPayloads { get; }

        /// <summary>Estimated bytes queued at capture time.</summary>
        public long QueuedBytes { get; }

        /// <summary>Lifetime count of dropped payloads.</summary>
        public long TotalDroppedPayloads { get; }

        /// <summary>Lifetime estimated bytes of dropped payloads.</summary>
        public long TotalDroppedBytes { get; }

        /// <summary>Payloads enqueued during the interval.</summary>
        public long IntervalEnqueuedPayloads { get; }

        /// <summary>Estimated bytes enqueued during the interval.</summary>
        public long IntervalEnqueuedBytes { get; }

        /// <summary>Payloads sent during the interval.</summary>
        public long IntervalSentPayloads { get; }

        /// <summary>Estimated bytes sent during the interval.</summary>
        public long IntervalSentBytes { get; }

        /// <summary>Payloads dropped during the interval.</summary>
        public long IntervalDroppedPayloads { get; }

        /// <summary>Estimated bytes dropped during the interval.</summary>
        public long IntervalDroppedBytes { get; }

        /// <summary>Failed send attempts during the interval.</summary>
        public long IntervalSendFailures { get; }

        /// <summary>Whether the waitSnd throttle was engaged at capture time.</summary>
        public bool IsThrottled { get; }

        /// <summary>Throttle reason code; never blank ("none" when no reason was supplied).</summary>
        public string ThrottleCode { get; }

        /// <summary>Most recent waitSnd reading at capture time.</summary>
        public int LastWaitSnd { get; }
    }
}
