﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using PacMap.Packet;
using PacMap.General;
using System.IO;
using PacMap.Anonymize;
using System.Drawing;

namespace PacMap.Protocols
{
    /// <summary>
    /// IPv4 protocol layer: parses an IPv4 header from raw bytes, exposes its fields,
    /// anonymizes them according to the packet's anonymization settings and serializes
    /// the (possibly modified) header together with its encapsulated child protocol.
    /// </summary>
    public class IPv4 : Protocol
    {
        /// <summary>
        /// Constructs IPv4 Class
        /// </summary>
        /// <param name="input">bytes array starting with the IPv4 header</param>
        /// <param name="packet">packet frame</param>
        public IPv4(byte[] input, Packet.Packet packet) : base(input, packet)
        {
            shortName = "IPv4";

            using (Stream stream = new MemoryStream(input))
            {
                #region GET IPv4 Intel
                //
                // Byte 0: Version (high nibble) & Internet Header Length (IHL, low nibble)
                //
                {
                    int versionAndIhl = stream.ReadByte();
                    ihl = (versionAndIhl & 0x0F) << 2; // IHL counts 32-bit words => convert to bytes
                }

                //
                // Byte 1: Differentiated Services Code Point (DSCP, upper 6 bits)
                //         & Explicit Congestion Notification (ECN, lower 2 bits)
                //
                {
                    int tos = stream.ReadByte();
                    dscp = (tos & 0xFC) >> 2;
                    ecn = tos & 0x03; // ECN is the two least significant bits (RFC 3168)
                }

                //
                // Total Length (TL) & Identification
                //
                tl = ReadWordBigEndian(stream);
                identification = ReadWordBigEndian(stream);

                //
                // Flags (upper 3 bits) & Fragment Offset (FO, lower 13 bits)
                //
                {
                    int word = ReadWordBigEndian(stream);
                    flags = (word >> 13) & 0x07;
                    fo = word & 0x1FFF;
                }

                //
                // Time To Live (TTL)
                //
                ttl = stream.ReadByte();

                //
                // Protocol
                //
                {
                    int protocolNo = stream.ReadByte();
                    protocol = Protocols.ProtocolsInfo.ProtocolsNo[protocolNo];
                    packet.Protocol = protocol.ToString();
                }

                //
                // Header Checksum (verified below)
                //
                hChecksum = ReadWordBigEndian(stream);

                //
                // Source (left) & Destination (right) IP Addresses
                //
                SourceIP = ReadBlock(stream, 4).ConvertIPv4ToString();
                DestinationIP = ReadBlock(stream, 4).ConvertIPv4ToString();

                //
                // Additional Header Options (bytes 20 .. IHL-1)
                //
                {
                    int moreSize = ihl - 20;
                    if (moreSize > 0)
                        hOptions = ReadBlock(stream, moreSize);
                }
                #endregion

                //
                // Verify Header Checksum
                //
                isHChecksum = (hChecksum == CorrectHeaderChecksum);

                packet.IPSource = SourceIP;
                packet.IPDestination = DestinationIP;
                packet.ProtocolNo = Protocol.GetProtocolNo();

                //--------------------------------------------
                // Everything after the header is handed to the encapsulated protocol
                int restSize = (int)(stream.Length - stream.Position);
                byte[] rest = new byte[restSize];
                int readed = stream.Read(rest, 0, restSize);

                if (readed == 0)
                    return;

                child = Protocols.ProtocolsInfo.CreateAppropriateProtocolContent(protocol, 0, 0, packet, rest); // Encapsulate Inner Protocol
            }
        }



        /******************************************************************************************/
        // :: DATA FIELDs ::
        private int ihl = 0;                // Internet Header Length in bytes (IHL nibble * 4)
        private int dscp = 0;               // Differentiated Services Code Point (6 bits)
        private int ecn = 0;                // Explicit Congestion Notification (2 bits)
        private int tl = 0;                 // Total Length from the header (rewritten by Save)
        private int identification = 0;
        private int flags = 0;              // 3 bits: reserved, don't fragment, more fragments
        private int fo = 0;                 // Fragment Offset (13 bits)
        private int ttl = 0;
        private ProtocolType protocol = 0;
        private int hChecksum = 0;          // checksum as read (or as reset by Anonymize)
        private bool isHChecksum = false;   // true when the read checksum matched the computed one
        private string localIP = "";
        private string remoteIP = "";
        private byte[] hOptions = null;     // header options; only set when IHL > 20

        private string lastError = "";      // NOTE(review): never assigned within this class


        /******************************************************************************************/
        // :: PROPERTIEs ::
        #region PROPERTIEs
        /// <summary>
        /// Get Packet Direction (taken from the owning packet, the same value that
        /// SourceIP / DestinationIP use to map onto local / remote addresses)
        /// </summary>
        public Direction Direction
        {
            get
            {
                return packet.Direction;
            }
        }

        /// <summary>
        /// True when an error message has been recorded
        /// </summary>
        public bool AnyError
        {
            get
            {
                return !String.IsNullOrEmpty(lastError);
            }
        }

        /// <summary>
        /// Get Internet Header Length (in bytes)
        /// </summary>
        public int IHL
        {
            get
            {
                return ihl;
            }
        }

        /// <summary>
        /// Get Differentiated Services Code Point
        /// </summary>
        public int DSCP
        {
            get
            {
                return dscp;
            }
        }

        /// <summary>
        /// Get Explicit Congestion Notification
        /// </summary>
        public int ECN
        {
            get
            {
                return ecn;
            }
        }

        /// <summary>
        /// Get Packet Total Length (expected one, as stated in the header)
        /// </summary>
        public int TL
        {
            get
            {
                return tl;
            }
        }

        /// <summary>
        /// Get Packet Captured Length (true one)
        /// </summary>
        public int CTL
        {
            get
            {
                return packet.Frame.CapturedPacketSize;
            }
        }

        /// <summary>
        /// Get Identification
        /// </summary>
        public int ID
        {
            get
            {
                return identification;
            }
        }

        /// <summary>
        /// Get Flags
        /// </summary>
        public int FLAGS
        {
            get
            {
                return flags;
            }
        }

        /// <summary>
        /// Get Fragment offset
        /// </summary>
        public int FO
        {
            get
            {
                return fo;
            }
        }

        /// <summary>
        /// Get Time To Live
        /// </summary>
        public int TTL
        {
            get
            {
                return ttl;
            }
        }

        /// <summary>
        ///  Get Protocol Type
        /// </summary>
        public ProtocolType Protocol
        {
            get
            {
                return protocol;
            }
        }

        /// <summary>
        /// Get or Set Local IP Address
        /// </summary>
        public string IP
        {
            get
            {
                return localIP;
            }
            set
            {
                localIP = value;
            }
        }

        /// <summary>
        /// Get or Set Remote IP Address
        /// </summary>
        public string RemoteIP
        {
            get
            {
                return remoteIP;
            }
            set
            {
                remoteIP = value;
            }
        }

        /// <summary>
        /// True when the source (left) address is the local one, i.e. the packet
        /// travels LocalToRemote; otherwise the source is the remote address.
        /// </summary>
        private bool SourceIsLocal
        {
            get
            {
                return packet.Direction == PacMap.Packet.Direction.LocalToRemote;
            }
        }

        /// <summary>
        /// Get or Set Source IP Address (Left Address).
        /// Setting it also updates packet.IPSource.
        /// </summary>
        public string SourceIP
        {
            get
            {
                return SourceIsLocal ? localIP : remoteIP;
            }
            set
            {
                packet.IPSource = value;
                if (SourceIsLocal)
                    localIP = value;
                else
                    remoteIP = value;
            }
        }

        /// <summary>
        /// Get or Set Destination IP Address (Right Address).
        /// Always stored in the slot opposite to SourceIP, so the two never overwrite each other.
        /// Setting it also updates packet.IPDestination.
        /// </summary>
        public string DestinationIP
        {
            get
            {
                return SourceIsLocal ? remoteIP : localIP;
            }
            set
            {
                packet.IPDestination = value;
                if (SourceIsLocal)
                    remoteIP = value;
                else
                    localIP = value;
            }
        }

        /// <summary>
        /// Count Correct Header Checksum: one's complement of the one's complement sum
        /// of all 16-bit header words, the checksum field itself excluded (RFC 1071)
        /// </summary>
        private int CorrectHeaderChecksum
        {
            get
            {
                byte[] header = CreateHeader(false);

                long sum = 0;
                for (int offset = 0; offset + 1 < header.Length; offset += 2)
                {
                    if (offset == 10) // 11th & 12th bytes are checksum bytes => skip them
                        continue;
                    sum += (header[offset] << 8) | header[offset + 1];
                }

                // Fold carries back into the low 16 bits until none remain; a single fold
                // is not enough when adding the carry itself overflows 0xFFFF.
                while ((sum >> 16) != 0)
                    sum = (sum & 0xFFFF) + (sum >> 16);

                return (int)(~sum & 0xFFFF);
            }
        }

        /// <summary>
        /// Get Original Header Checksum
        /// </summary>
        public int OriginalHeaderChecksum
        {
            get
            {
                return hChecksum;
            }
        }

        /// <summary>
        /// Check Header Checksum is Correct (result of the verification done in the constructor)
        /// </summary>
        public bool IsHeaderCheckumCorrect
        {
            get
            {
                return isHChecksum;
            }
        }
        #endregion



        /******************************************************************************************/
        // :: FUNCTIONs ::


        /// <summary>
        /// Read a big-endian (network order) 16-bit word; bytes missing at the end of the stream read as zero
        /// </summary>
        private static int ReadWordBigEndian(Stream stream)
        {
            byte[] bytes = ReadBlock(stream, 2);
            return (bytes[0] << 8) | bytes[1];
        }

        /// <summary>
        /// Read up to <paramref name="count"/> bytes; bytes missing at the end of the stream stay zero
        /// </summary>
        private static byte[] ReadBlock(Stream stream, int count)
        {
            byte[] bytes = new byte[count];
            stream.Read(bytes, 0, count);
            return bytes;
        }

        /// <summary>
        /// Write the low 16 bits of <paramref name="value"/> in big-endian (network) order
        /// </summary>
        private static void WriteWordBigEndian(Stream stream, int value)
        {
            stream.WriteByte((byte)((value & 0xFF00) >> 8));
            stream.WriteByte((byte)(value & 0x00FF));
        }

        /// <summary>
        /// Get Actual IPv4 Header in Bytes (advanced method)
        /// </summary>
        /// <param name="exceptChecksums">
        /// true => fill the checksum field according to AppSettings;
        /// false => leave the checksum field zeroed (used to compute the checksum)
        /// </param>
        /// <returns>Array of IHL bytes (20 - 60B for a valid header)</returns>
        private byte[] CreateHeader(bool exceptChecksums)
        {
            byte[] result = new byte[ihl];

            if (AnyError)
                return result;

            using (Stream stream = new MemoryStream(result))
            {
                stream.WriteByte((byte)((4 << 4) | (ihl >> 2)));   // Version 4 + IHL in 32-bit words
                stream.WriteByte((byte)((dscp << 2) | ecn));       // DSCP (6 bits) + ECN (2 bits)

                WriteWordBigEndian(stream, tl);
                WriteWordBigEndian(stream, identification);
                WriteWordBigEndian(stream, (flags << 13) | (fo & 0x1FFF)); // Flags (3 bits) + Fragment Offset (13 bits)

                stream.WriteByte((byte)ttl);
                int ptclno = Protocols.ProtocolsInfo.ProtocolsNo.First(item => item.Value == protocol).Key;
                stream.WriteByte((byte)ptclno);

                if (exceptChecksums)
                {
                    int newChecksum;
                    if (isHChecksum)
                    {
                        // header is correct => recalculate if supported, else keep the original
                        newChecksum = AppSettings.RecalculateChecksums ? CorrectHeaderChecksum : OriginalHeaderChecksum;
                    }
                    else
                    {
                        // header is bad => zero it if "bad to bad" is supported, else keep the original
                        newChecksum = AppSettings.SetBadChecksumsToBad ? 0 : OriginalHeaderChecksum;
                    }
                    WriteWordBigEndian(stream, newChecksum);
                }
                else
                {
                    stream.Position = stream.Position + 2; // leave checksum bytes zeroed
                }

                stream.Write(SourceIP.ConvertIPv4ToArray(), 0, 4);
                stream.Write(DestinationIP.ConvertIPv4ToArray(), 0, 4);
            }

            int moreSize = ihl - 20;
            if (moreSize > 0)
                Array.Copy(hOptions, 0, result, 20, moreSize);

            return result;
        }

        /// <summary>
        /// Get Actual IPv4 Header in Bytes (checksum field filled according to AppSettings)
        /// </summary>
        /// <returns>Array of IHL bytes (20 - 60B for a valid header)</returns>
        private byte[] CreateHeader()
        {
            return CreateHeader(true);
        }



        /******************************************************************************************/
        // :: GLOBAL METHODs ::


        /// <summary>
        /// Start AnonMapping Protocol: anonymizes addresses, TOS, identification, flags,
        /// fragment offset and TTL, fixes the header checksum according to AppSettings,
        /// then anonymizes (or replaces) the encapsulated child protocol.
        /// </summary>
        public override void Anonymize()
        {
            #region AnonMapping IPv4 Header
            object anonLocalIP = null;
            object anonRemoteIP = null;
            object anonSourceIP = null;
            object anonDestinationIP = null;
            object anonTos = null;
            object anonIdentification = null;
            object anonFlags = null;
            object anonFragment = null;

            //
            // AnonMapping
            //
            Anonymizator anonymizator = new Anonymizator();
            anonymizator.AnonMap(IP, out anonLocalIP, packet.AnonSettings.IPv4.Local, packet.AnonSettings.IPv4.Local_ReplaceWith, AnonimizatorInput.IPv4);
            anonymizator.AnonMap(RemoteIP, out anonRemoteIP, packet.AnonSettings.IPv4.Remote, packet.AnonSettings.IPv4.Remote_ReplaceWith, AnonimizatorInput.IPv4);
            anonymizator.AnonMap(SourceIP, out anonSourceIP, packet.AnonSettings.IPv4.Source, packet.AnonSettings.IPv4.Source_ReplaceWith, AnonimizatorInput.IPv4);
            anonymizator.AnonMap(DestinationIP, out anonDestinationIP, packet.AnonSettings.IPv4.Destination, packet.AnonSettings.IPv4.Destination_ReplaceWith, AnonimizatorInput.IPv4);

            int tos = (DSCP << 2) | (ECN);
            anonymizator.AnonMap(tos, out anonTos, packet.AnonSettings.IPv4.Tos, packet.AnonSettings.IPv4.Tos_ReplaceWith, AnonimizatorInput.BYTE);
            anonymizator.AnonMap(ID, out anonIdentification, packet.AnonSettings.IPv4.Identification, packet.AnonSettings.IPv4.Identification_ReplaceWith, AnonimizatorInput.WORD);
            anonymizator.AnonMap(FLAGS, out anonFlags, packet.AnonSettings.IPv4.Flags, packet.AnonSettings.IPv4.Flags_ReplaceWith, AnonimizatorInput.BYTE);
            anonymizator.AnonMap(FO, out anonFragment, packet.AnonSettings.IPv4.FragmentOffset, packet.AnonSettings.IPv4.FragmentOffset_ReplaceWith, AnonimizatorInput.WORD);


            if ((packet.AnonSettings.IPv4.Source != AnonymizationType.None) || (packet.AnonSettings.IPv4.Destination != AnonymizationType.None))
            {
                SourceIP = packet.AnonSettings.IPv4.Source_Mask.Proceed(SourceIP.ConvertIPv4ToArray(), ((string)anonSourceIP).ConvertIPv4ToArray()).ConvertIPv4ToString();
                DestinationIP = packet.AnonSettings.IPv4.Destination_Mask.Proceed(DestinationIP.ConvertIPv4ToArray(), ((string)anonDestinationIP).ConvertIPv4ToArray()).ConvertIPv4ToString();
            }
            if ((packet.AnonSettings.IPv4.Local != AnonymizationType.None) || (packet.AnonSettings.IPv4.Remote != AnonymizationType.None))
            {
                IP = packet.AnonSettings.IPv4.Local_Mask.Proceed(IP.ConvertIPv4ToArray(), ((string)anonLocalIP).ConvertIPv4ToArray()).ConvertIPv4ToString();
                RemoteIP = packet.AnonSettings.IPv4.Remote_Mask.Proceed(RemoteIP.ConvertIPv4ToArray(), ((string)anonRemoteIP).ConvertIPv4ToArray()).ConvertIPv4ToString();
            }

            tos = (int)anonTos;
            dscp = (tos & 0xFC) >> 2;
            ecn = (tos & 0x03);
            identification = (int)anonIdentification;
            flags = ((int)anonFlags & 0x07);
            fo = ((int)anonFragment & 0x1FFF);
            ttl = packet.AnonSettings.IPv4.Ttl.Run(ttl);

            packet.IPSource = SourceIP;
            packet.IPDestination = DestinationIP;

            //
            // Check IPv4 Header
            //
            if ((!IsHeaderCheckumCorrect) && (AppSettings.SetBadChecksumsToBad))
                hChecksum = 0;
            if (AppSettings.RecalculateChecksums)
                hChecksum = CorrectHeaderChecksum;
            #endregion


            if (HasChild)
            {
                if (packet.AnonSettings.IPv4.Content == AnonContentType.Standard) // default settings
                    child.Anonymize();
                else
                {
                    // Non-standard content mode: anonymize the serialized child as raw bytes
                    // and replace it with a generic Protocol wrapping the result.
                    byte[] payload = child.Save();
                    object payloadOut;
                    anonymizator.AnonMap(payload, out payloadOut, packet.AnonSettings.IPv4.Content.ConvertToAnonymizationType(), packet.AnonSettings.IPv4.Content_ReplaceWith, payload.Length);
                    payload = (byte[])payloadOut;
                    child = new Protocol(payload, packet);
                }
            }
        }


        /// <summary>
        /// Save Protocol to bytes array.
        /// With a serialized child the result is header + child and Total Length is updated;
        /// otherwise the result has the original content length and only the header is written
        /// (NOTE(review): the remaining bytes stay zero in that case).
        /// </summary>
        /// <returns>bytes array</returns>
        public override byte[] Save()
        {
            byte[] childOutput = HasChild ? child.Save() : null;

            byte[] result;
            if (childOutput != null)
            {
                result = new byte[IHL + childOutput.Length];
                tl = result.Length; // update Total Length before the header (and its checksum) is built
            }
            else
                result = new byte[content.Length];

            using (Stream stream = new MemoryStream(result))
            {
                byte[] header = CreateHeader();
                stream.Write(header, 0, header.Length);

                if (childOutput != null)
                    stream.Write(childOutput, 0, childOutput.Length);
            }

            return result;
        }

        /// <summary>
        /// Get Full Packet Information for Advanced Packet Viewing
        /// </summary>
        /// <returns>
        /// List of three objects: a name => details dictionary, a list of (byte range, children)
        /// hex mappings and a list of group names
        /// </returns>
        public override object GetFullInformation()
        {
            IList<object> result = new List<object>();
            IDictionary<string, object> browserInfo = new Dictionary<string, object>();
            IList<KeyValuePair<Point, object>> browserHex = new List<KeyValuePair<Point, object>>();
            IList<string> browserGroup = new List<string>();
            result.Add(browserInfo);
            result.Add(browserHex);
            result.Add(browserGroup);

            #region __GET NAMEs
            {
                string name = String.Format("Internet Protocol Version 4, Src: {0}, Dst: {1}", SourceIP, DestinationIP);
                string item1 = "Version: 4";
                string item2 = String.Format("Header length: {0}", ihl);
                string item3 = String.Format("Differentiated Services Field: (DSCP 0x{0:x2}, ECN 0x{1:x2})", dscp, ecn);
                string item4 = String.Format("Total length: {0}", tl);
                string item5 = String.Format("Identification: {0}", identification);
                string item6 = String.Format("Flags: 0x{0:x2}", flags);

                bool reservedBit = (flags & 0x4) == 0x4;
                bool dontFragment = (flags & 0x2) == 0x2;
                bool moreFragments = (flags & 0x1) == 0x1;
                string item6_1 = String.Format("{0}... .... = Reserved bit: {1}", reservedBit ? "1" : "0", reservedBit ? "Set" : "Not set");
                string item6_2 = String.Format(".{0}.. .... = Don't fragment: {1}", dontFragment ? "1" : "0", dontFragment ? "Set" : "Not set");
                string item6_3 = String.Format("..{0}. .... = More fragments: {1}", moreFragments ? "1" : "0", moreFragments ? "Set" : "Not set");
                IDictionary<string, object> item6Inner = new Dictionary<string, object>();
                item6Inner[item6_1] = null;
                item6Inner[item6_2] = null;
                item6Inner[item6_3] = null;

                string item7 = String.Format("Fragment offset: {0}", fo);
                string item8 = String.Format("Time to live: {0}", ttl);
                string item9 = String.Format("Protocol: {0}", protocol);
                string item10 = String.Format("Header checksum: 0x{0:x4}", OriginalHeaderChecksum & 0xFFFF)
                    + (IsHeaderCheckumCorrect ? " [correct]" : " [bad]");
                string item11 = String.Format("Source: {0}", SourceIP);
                string item12 = String.Format("Destination: {0}", DestinationIP);

                IDictionary<string, object> inner = new Dictionary<string, object>();
                inner[item1] = null;
                inner[item2] = null;
                inner[item3] = null;
                inner[item4] = null;
                inner[item5] = null;
                inner[item6] = item6Inner;
                inner[item7] = null;
                inner[item8] = null;
                inner[item9] = null;
                inner[item10] = null;
                inner[item11] = null;
                inner[item12] = null;
                browserInfo[name] = inner;
            }
            #endregion

            #region __GET HEX
            {
                // Points are (first byte, last byte), 1-based, within this header
                Point nameHex = new Point(1, ihl);
                Point item6Hex1to3 = new Point(7, 7);
                IList<KeyValuePair<Point, object>> item6HexContent = new List<KeyValuePair<Point, object>>();
                item6HexContent.Add(new KeyValuePair<Point, object>(item6Hex1to3, null));
                item6HexContent.Add(new KeyValuePair<Point, object>(item6Hex1to3, null));
                item6HexContent.Add(new KeyValuePair<Point, object>(item6Hex1to3, null));

                IList<KeyValuePair<Point, object>> innerHex = new List<KeyValuePair<Point, object>>();
                innerHex.Add(new KeyValuePair<Point, object>(new Point(1, 1), null));            // version
                innerHex.Add(new KeyValuePair<Point, object>(new Point(1, 1), null));            // header length
                innerHex.Add(new KeyValuePair<Point, object>(new Point(2, 2), null));            // DSCP / ECN
                innerHex.Add(new KeyValuePair<Point, object>(new Point(3, 4), null));            // total length
                innerHex.Add(new KeyValuePair<Point, object>(new Point(5, 6), null));            // identification
                innerHex.Add(new KeyValuePair<Point, object>(new Point(7, 7), item6HexContent)); // flags
                innerHex.Add(new KeyValuePair<Point, object>(new Point(7, 8), null));            // fragment offset
                innerHex.Add(new KeyValuePair<Point, object>(new Point(9, 9), null));            // TTL
                innerHex.Add(new KeyValuePair<Point, object>(new Point(10, 10), null));          // protocol
                innerHex.Add(new KeyValuePair<Point, object>(new Point(11, 12), null));          // header checksum
                innerHex.Add(new KeyValuePair<Point, object>(new Point(13, 16), null));          // source
                innerHex.Add(new KeyValuePair<Point, object>(new Point(17, 20), null));          // destination
                browserHex.Add(new KeyValuePair<Point, object>(nameHex, innerHex));
            }
            #endregion

            #region __GET GROUP
            {
                browserGroup.Add(shortName);
            }
            #endregion

            return result;
        }



        /// <summary>
        /// One-line summary: "source -> destination protocol sizeB"
        /// </summary>
        public override string ToString()
        {
            return String.Format("{0} -> {1} {2} {3}B", SourceIP, DestinationIP, protocol, packet.Frame.OriginalPacketSize);
        }
    }
}
