#include "utility.h"
#include <ctype.h>

namespace ADNS
{
	// Converts a DateTime to a 32-bit count of whole seconds since 1970-01-01 00:00:00.
	//
	// Only dt->Ticks is used; dt->Kind is ignored, so the caller must already hold the
	// value in the time base it wants (NOTE(review): presumably UTC — confirm with callers).
	// Sub-second precision is truncated toward zero. The 64-bit second count is reduced
	// modulo 2^32 by the explicit cast, so dates before 1970 or after early 2106 wrap
	// around rather than failing.
	UInt32 DateTimeToTimeStamp(DateTime^ dt)
	{
		// Value type on the stack: no need to box the epoch or allocate a TimeSpan
		// just to read the static TicksPerSecond constant.
		DateTime epoch(1970, 1, 1, 0, 0, 0);

		Int64 seconds = (dt->Ticks - epoch.Ticks) / TimeSpan::TicksPerSecond;

		return static_cast<UInt32>(seconds);
	}

	// Converts a 32-bit count of seconds since 1970-01-01 00:00:00 back to a DateTime.
	//
	// The result is built from raw ticks, so its Kind is DateTimeKind::Unspecified
	// (same as before). The multiplication is done in 64 bits, so every UInt32 value
	// maps to a valid DateTime (up to early 2106).
	DateTime^ TimeStampToDateTime(UInt32 tsa)
	{
		// Stack-allocated epoch; TicksPerSecond is read statically instead of through
		// a throwaway boxed TimeSpan.
		DateTime epoch(1970, 1, 1, 0, 0, 0);

		Int64 offsetTicks = TimeSpan::TicksPerSecond * static_cast<Int64>(tsa);

		return gcnew DateTime(epoch.Ticks + offsetTicks);
	}

	// Sets (val == true) or clears (val == false) one bit in the byte array bm.
	// Bits are numbered MSB-first: bit 0 is the 0x80 bit of bm[0], bit 7 is the
	// 0x01 bit of bm[0], bit 8 is the 0x80 bit of bm[1], and so on.
	// An out-of-range bit index surfaces as the managed array's index exception.
	Void SetBit(array<Byte>^ bm, unsigned int bit, bool val)
	{
		const int byteIndex = static_cast<int>(bit >> 3);
		const int mask = 0x80 >> (bit & 7);

		if (!val)
		{
			bm[byteIndex] = static_cast<Byte>(bm[byteIndex] & ~mask);
			return;
		}

		bm[byteIndex] = static_cast<Byte>(bm[byteIndex] | mask);
	}

	// Finds the last non-zero byte in bitmap[startpos .. startpos + checklength - 1].
	//
	// Returns the offset of that byte RELATIVE to startpos, or 0xFF ((Byte)-1, the
	// same sentinel the original initialised to) when every byte in the range is zero,
	// the range is empty, or bitmap is null. Offsets >= 255 cannot be distinguished
	// from the sentinel, so the range should be shorter than 256 bytes.
	// A range extending past the array surfaces as the managed index exception.
	//
	// Previously the loop tested the index (i != 0) instead of the data, so the
	// bitmap was never inspected. It also used a Byte counter that looped forever
	// for checklength > 255.
	Byte find_last_non_zero_byte(array<Byte>^ bitmap, int startpos, int checklength)
	{
		const Byte notFound = 0xFF;

		if (bitmap == nullptr || checklength <= 0)
			return notFound;

		// Scan from the end so the first hit is the answer.
		for (int offset = checklength - 1; offset >= 0; --offset)
		{
			if (bitmap[startpos + offset] != 0)
				return static_cast<Byte>(offset);
		}

		return notFound;
	}

	// Computes the uncompressed wire length of the domain name starting at pkt[startpos],
	// following compression pointers (top two bits 11).
	//
	// Returns the sum of (label length + 1) over all labels, NOT counting the final
	// 0x00 root byte, or -1 if the name is malformed. Malformed means any of:
	//   - pkt is null, or startpos is out of range;
	//   - the name runs off the end of the packet;
	//   - a pointer is truncated (its second byte lies past the end);
	//   - a pointer does not point strictly before the previous jump target (or before
	//     startpos for the first pointer). This is the rule BIND uses. It is satisfied
	//     by every legitimate "pointer to a prior occurrence", and it guarantees that
	//     crafted pointer loops terminate instead of spinning forever;
	//   - a label length byte uses the reserved 01/10 type bits (lengths above 63).
	int find_length_of_domain_label(array<Byte>^ pkt, int startpos)
	{
		if (pkt == nullptr || startpos < 0)
			return -1;

		int pos = startpos;
		int length = 0;
		int limit = startpos;  // every pointer target must be < limit; shrinks on each jump

		while (pos < pkt->Length)
		{
			int labelbyte = pkt[pos];

			if (labelbyte == 0)
				return length;  // root label terminates the name

			int labeltype = labelbyte & 0xC0;

			if (labeltype == 0xC0)
			{
				// Compression pointer: 14-bit offset split over two bytes.
				if (pos + 1 >= pkt->Length)
					return -1;

				int target = ((labelbyte & 0x3F) << 8) + pkt[pos + 1];
				if (target >= limit)
					return -1;  // forward or self-referencing pointer -> possible loop

				limit = target;
				pos = target;
			}
			else if (labeltype != 0)
			{
				return -1;  // 0x40 / 0x80 label types are not plain labels
			}
			else
			{
				length += labelbyte + 1;
				pos += labelbyte + 1;  // +1 to skip to the next length byte
			}
		}

		return -1;  // ran off the end of the packet without a terminating 0x00
	}

	// Reads the (possibly compressed) domain name at packet[startpos] and returns it in
	// uncompressed wire form: a sequence of length-prefixed labels followed by 0x00.
	//
	// reallen receives the number of bytes the name occupies at startpos in the packet.
	// That is every byte up to and including the first compression pointer (2 bytes), or
	// up to and including the terminating 0x00 when no pointer is present. The caller
	// uses it to advance past the name.
	//
	// On a malformed name (as judged by find_length_of_domain_label) this returns
	// nullptr and sets reallen to -1.
	array<Byte>^ ReadDomainFromPacket(array<Byte>^ packet, int startpos, int &reallen)
	{
		reallen = 0;

		int len = find_length_of_domain_label(packet, startpos);
		if (len == -1)
		{
			reallen = -1;
			return nullptr;
		}

		// +1 for the terminating root byte. Managed arrays are zero-initialised, so
		// that final 0x00 is already in place and no explicit Clear is needed.
		array<Byte>^ output = gcnew array<Byte>(len + 1);
		int outputpos = 0;
		int pos = startpos;
		bool pointer = false;  // bytes after the first pointer do not count toward reallen

		// find_length_of_domain_label already walked this exact path, so the traversal
		// is known to terminate and stay in bounds.
		while (packet[pos] != 0)
		{
			int labelbyte = packet[pos];

			if ((labelbyte & 0xC0) == 0xC0)
			{
				if (!pointer)
				{
					pointer = true;
					reallen += 2;  // the pointer and its offset byte
				}
				pos = ((labelbyte & 0x3F) << 8) + packet[pos + 1];
			}
			else
			{
				int chunk = labelbyte + 1;  // length byte plus label data
				if (!pointer)
					reallen += chunk;
				Array::Copy(packet, pos, output, outputpos, chunk);
				outputpos += chunk;
				pos += chunk;
			}
		}

		// The loop only exits on a 0x00 byte; it belongs to the on-wire name only if
		// no pointer was followed.
		if (!pointer)
			reallen += 1;

		return output;
	}

	// Base32-encodes src using the RFC 4648 alphabet in lower case ("a".."z", "2".."7"),
	// padding the output with '=' to a multiple of 8 characters.
	// Example: {'f'} -> "my======", {'f','o'} -> "mzxq====".
	// Returns an empty array for empty input and nullptr for a null input.
	//
	// This replaces a hand-unrolled encoder with the following bugs:
	//   - it read src[1..4] instead of src[inputpos+1..4];
	//   - it sent exactly-5-byte tails to the padding path;
	//   - it read past the end of short tails;
	//   - it indexed the alphabet with output[] instead of tmp[];
	//   - its 2*len output buffer was too small for 1-3 byte inputs.
	array<Byte>^ Base32Encode(array<Byte>^ src)
	{
		if (src == nullptr)
			return nullptr;

		array<Byte>^ b32chars = gcnew array<Byte>{'a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','2','3','4','5','6','7'};
		const Byte Pad32 = '=';

		// Every started 5-byte group produces exactly 8 output characters.
		int groups = (src->Length + 4) / 5;
		array<Byte>^ output = gcnew array<Byte>(groups * 8);
		int outputpos = 0;

		// Bit accumulator: holds `bits` not-yet-emitted bits (never more than 12).
		unsigned int buffer = 0;
		int bits = 0;

		for (int inputpos = 0; inputpos < src->Length; ++inputpos)
		{
			buffer = (buffer << 8) | src[inputpos];
			bits += 8;

			while (bits >= 5)
			{
				bits -= 5;
				output[outputpos++] = b32chars[static_cast<int>((buffer >> bits) & 0x1F)];
			}

			buffer &= (1u << bits) - 1;  // drop bits that have been emitted
		}

		// Flush the remaining 1-4 bits, zero-filled on the right.
		if (bits > 0)
			output[outputpos++] = b32chars[static_cast<int>((buffer << (5 - bits)) & 0x1F)];

		while (outputpos < output->Length)
			output[outputpos++] = Pad32;

		return output;
	}

	// Maps one Base32 character (RFC 4648 alphabet, either case) to its 5-bit value,
	// or returns -1 if ch is not in the alphabet.
	static int Base32SymbolValue(Byte ch)
	{
		if (ch >= 'a' && ch <= 'z')
			return ch - 'a';
		if (ch >= 'A' && ch <= 'Z')
			return ch - 'A';
		if (ch >= '2' && ch <= '7')
			return ch - '2' + 26;
		return -1;
	}

	// Decodes Base32 text (RFC 4648 alphabet, case-insensitive) into bytes.
	//
	// Whitespace is ignored anywhere. Padding is optional. When present, it must start
	// after a valid partial quantum (2, 4, 5 or 7 symbols) and must complete that
	// quantum to 8 characters; only '=' and whitespace may follow the first '='.
	// Leftover low bits in the final symbol are not checked (lenient decoding).
	// Returns nullptr on null input, an invalid character, an impossible quantum
	// length (1, 3 or 6 trailing symbols) or bad padding.
	//
	// This replaces a state machine with the following bugs:
	//   - invalid characters were never detected (Byte compared with -1);
	//   - state 4 shifted by 4 instead of 1;
	//   - state 5 fell through into state 6;
	//   - state 7 fell into `default: return nullptr`, so every full group failed;
	//   - `ch =! Pad32` was an assignment typo;
	//   - the output was not trimmed for unpadded input.
	array<Byte>^ Base32Decode(array<Byte>^ src)
	{
		if (src == nullptr)
			return nullptr;

		const Byte Pad32 = '=';

		// Upper bound on decoded size: 5 bytes per 8 symbols, plus a partial group.
		array<Byte>^ output = gcnew array<Byte>((src->Length / 8) * 5 + 5);
		int outputpos = 0;

		unsigned int buffer = 0;  // holds `bits` pending bits (never more than 12)
		int bits = 0;
		int symbols = 0;          // number of data symbols decoded
		int inputpos = 0;

		for (; inputpos < src->Length; ++inputpos)
		{
			Byte ch = src[inputpos];

			if (ch == Pad32)
				break;
			if (Char::IsWhiteSpace(ch))
				continue;

			int value = Base32SymbolValue(ch);
			if (value < 0)
				return nullptr;

			buffer = (buffer << 5) | static_cast<unsigned int>(value);
			bits += 5;
			++symbols;

			if (bits >= 8)
			{
				bits -= 8;
				output[outputpos++] = static_cast<Byte>(buffer >> bits);
			}

			buffer &= (1u << bits) - 1;  // keep only the bits not yet emitted
		}

		// 1, 3 or 6 symbols in the last quantum cannot encode a whole number of bytes.
		int remainder = symbols % 8;
		if (remainder == 1 || remainder == 3 || remainder == 6)
			return nullptr;

		if (inputpos < src->Length)
		{
			// We stopped at a '='; the rest must be padding (and whitespace) only.
			int pads = 0;
			for (; inputpos < src->Length; ++inputpos)
			{
				Byte ch = src[inputpos];
				if (ch == Pad32)
					++pads;
				else if (!Char::IsWhiteSpace(ch))
					return nullptr;
			}

			if (remainder == 0 || pads != 8 - remainder)
				return nullptr;
		}

		Array::Resize<Byte>(output, outputpos);
		return output;
	}


	// Returns the single upper-case hexadecimal digit for a nibble value (0-15),
	// or "X" for any value that does not fit in one hex digit.
	String^ GetHexChar(Byte val)
	{
		if (val > 15)
			return "X";

		String^ digits = "0123456789ABCDEF";
		return digits->Substring(val, 1);
	}

	// Formats a byte array as contiguous upper-case hexadecimal (e.g. {0x0A,0xFF} -> "0AFF").
	// Returns "-" for a null or empty array.
	//
	// The previous loop appended to a String with +=, which copies the whole string on
	// every append (O(n^2)). BitConverter::ToString produces the same upper-case digit
	// pairs in one pass, separated by '-', which is then stripped.
	String^ ArrayToHexString(array<Byte>^ a)
	{
		if (a == nullptr || a->Length == 0)
			return "-";

		return System::BitConverter::ToString(a)->Replace("-", String::Empty);
	}

}