/*
 *	Decrypter
 *	By Jeremy Lennert
 *	March 10, 2001
 *	v. 4.0.2
 *
 *	decrypter.cpp
 *
 *	Reads in ciphertext and expected frequencies, then
 *	calls subordinate functions to do actual
 *	cryptanalysis.
 */

// COMPILER DIRECTIVES
#include <fstream.h>
#include <time.h>
#include <assert.h>
#include <stdlib.h>
#include "Keys.h"
#include "FreqTable.h"
#include "Checker.h"

// FUNCTION PROTOTYPES
double grade(unsigned char *plaintext, Keys *key);	// Scores plaintext with the global
													// Checker `look`, returns the result
int complete(unsigned char *map);	// Returns true if map has a nonzero entry for
									// every entry dumped from obunigrams
void alter(unsigned char a, unsigned char b, unsigned char *text);	// Exchange
													// every a with b (and b with a) in a
													// null-terminated text string, in place
unsigned char firstc(unsigned long x);	// Returns bits 0-7 of x (x & FIRSTCHAR)
unsigned char secondc(unsigned long x);	// Returns bits 8-15 of x (x & SECONDCHAR, shifted down)
unsigned char thirdc(unsigned long x);	// Returns bits 16-23 of x (x & THIRDCHAR, shifted down)
int random(int low, int high);	// Returns a pseudo-random integer in [low, high], using rand()

/*
 *	- Cryptanalysis Functions -
 *
 *	All four functions share the same parameter list:
 *
 *	unsigned char *ct	- pointer to the ciphertext
 *	unsigned char *pt	- pointer to an array that may be overwritten
 *						  to store candidate plaintext
 *	float sec			- length of time (in seconds) after which to stop
 *						  searching for better keys; zero disables the time limit
 *	int quick			- if set, returns first key with a score that meets THRESHOLD
 *	int out				- if set, prints results to the `findings` stream
 *	int filter			- if set, and best key fails to meet THRESHOLD, returns a zero key
 *
 *	Each function returns a Keys object allocated with new; the caller
 *	is responsible for deleting it.  vigenere() is not yet implemented
 *	and always returns a zero key.
 */
Keys* xor(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);
Keys* shift(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);
Keys* substitution(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);
Keys* vigenere(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);

// GLOBAL VARIABLES
#define MAXLENGTH (5000)		// Maximum length of sample text, in bytes
#define THRESHOLD (-45000.0)	// Minimum score considered a successful decryption
#define FIRSTCHAR (0x0000FF)	// Mask: binary & keeps bits 0-7
#define SECONDCHAR (0x00FF00)	// Mask: binary & keeps bits 8-15
#define THIRDCHAR (0xFF0000)	// Mask: binary & keeps bits 16-23

FreqTable unigramfreqs;	// expected unigram frequencies (x10000), read from chars.txt
FreqTable bigramfreqs;	// expected bigram frequencies (x10000), read from digraphs.txt
FreqTable bigramlogs;	// natural logs of expected bigram frequencies (x10000)
FreqTable trigramfreqs;	// expected trigram frequencies (x10000), read from trigraphs.txt

FreqTable obunigrams;	// observed unigram counts (filled from uni[])
FreqTable obbigrams;	// observed bigram counts (filled from bi[][])
FreqTable obtrigrams;	// observed trigram counts (filled from tri[][][])

Checker *look;		// Checker that examines results; created in main() from bigramlogs
int length;		// Total length of ciphertext, in bytes (embedded nulls excluded)
unsigned char cipher[MAXLENGTH+1];	// stores ciphertext plus a terminating '\0'
unsigned char plain[MAXLENGTH+1];	// stores candidate plaintext
ofstream findings;	// Results of attempt to decrypt (main() opens "findings.txt")

// Raw tallies gathered while the ciphertext is read.  As globals they start
// zeroed; tri[] alone holds 256*256*256 ints, so it must not live on the stack.
int uni[256];	// observed unigram counts
int bi[256][256];	// observed bigram counts
int tri[256][256][256];	// observed trigram counts

int forget = 0;	// If set, grade() calls Checker::score() instead of
				// Checker::check(), so scores are not compared against the best it has seen


/*
 *	Program entry point.
 *
 *	1. Loads expected unigram, bigram and trigram statistics from
 *	   chars.txt, digraphs.txt and trigraphs.txt.
 *	2. Reads ciphertext.txt (at most MAXLENGTH bytes, embedded nulls
 *	   dropped) while counting observed unigrams, bigrams and trigrams.
 *	3. Tries XOR, shift and substitution cryptanalysis in that order and
 *	   stops at the first one whose returned key has a nonzero Type().
 *
 *	Results are written to findings.txt by the subordinate functions.
 *	Always returns 0; missing input files trip an assert.
 */
int main()
{
	ifstream ciphertext;// ciphertext
	ifstream stats;		// expected frequency tables
	unsigned char lastchar, thischar, nextchar, extra;
	int x, y, z;
	Keys* key;
	time_t t;

	srand(time(&t));	// Seed pseudo-random number generator with time
	
	findings.open("findings.txt");

	/*
	 *	Read in unigram frequencies.
	 *
	 *	Each read is checked with fail() rather than eof(): a number that
	 *	ends exactly at end-of-file sets eofbit even though it was read
	 *	successfully, so an eof() test would drop the last record of a
	 *	file with no trailing newline, and would never stop on malformed
	 *	(non-numeric) input.
	 */
	cout << "Reading in unigram frequencies . . . ";
	cout.flush();
	stats.open("chars.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(thischar);	// Unigram
			if (stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.fail()) break;
		stats >> y;		// Natural log of frequency (not stored for unigrams)
			if (stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		unigramfreqs.hash(thischar, x);
	}
	stats.close();
	stats.clear();	// Reset eof/fail state so the stream can be reused
	cout << "done.\n";
	cout.flush();

	/*
	 *	Read in bigram frequencies
	 */
	cout << "Reading in bigram frequencies . . . ";
	cout.flush();
	stats.open("digraphs.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(thischar);	// First letter of bigram
			if (stats.fail()) break;
		stats.get(nextchar);	// Next letter of bigram
			if (stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.fail()) break;
		stats >> y;		// Natural log of frequency
			if (stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		bigramfreqs.hash(thischar, nextchar, x);
		bigramlogs.hash(thischar, nextchar, y);
	}
	stats.close();
	stats.clear();	// Reset eof/fail state so the stream can be reused
	cout << "done.\n";
	cout.flush();

	look = new Checker(&bigramlogs);

	/*
	 *	Read in trigram frequencies
	 */
	cout << "Reading in trigram frequencies . . . ";
	cout.flush();
	stats.open("trigraphs.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(lastchar);	// First letter of trigram
			if (stats.fail()) break;
		stats.get(thischar);	// Next letter of trigram
			if (stats.fail()) break;
		stats.get(nextchar);	// Last letter of trigram
			if (stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.fail()) break;
		stats >> y;		// Natural log of frequency (not stored for trigrams)
			if (stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		trigramfreqs.hash(lastchar, thischar, nextchar, x);
	}
	stats.close();
	cout << "done.\n";
	cout.flush();

	/*
	 *	Read in ciphertext (truncate to MAXLENGTH characters).
	 *	Remove null terminators ('\0') embedded in file, then
	 *	append one to the end of the string.
	 *
	 *	Simultaneously count unigrams, bigrams, and trigrams
	 *	in ciphertext, then put them into hash tables.
	 */
	cout << "Reading in ciphertext . . . ";
	cout.flush();
	ciphertext.open("ciphertext.txt", ifstream::binary | ifstream::in);
	assert(!ciphertext.fail());
	length = 0;
	// Stop at MAXLENGTH so that cipher[length] below is still inside the
	// MAXLENGTH+1 byte buffer.
	while (length < MAXLENGTH)
	{
		ciphertext.get(thischar);
		if (ciphertext.fail()) break;	// End of file or read error
		if (thischar == '\0') continue;	// Don't store nulls

		cipher[length] = thischar;
		uni[thischar]++;
		if (length >= 1) bi[cipher[length-1]][thischar]++;
		if (length >= 2) tri[cipher[length-2]][cipher[length-1]][thischar]++;
		length++;
	}
	cipher[length] = '\0';
	ciphertext.close();
	cout << "done.\n\n";
	cout.flush();

	cout << "Compiling ciphertext heuristic data . . . ";
	cout.flush();
	for (x = 0; x < 256; x++)
	{
		obunigrams.hash(x, uni[x]);
		for (y = 0; y < 256; y++)
		{
			// Bigram form of hash(): two characters plus the count, the
			// same shape used for bigramfreqs above.
			obbigrams.hash((unsigned char)x, (unsigned char)y, bi[x][y]);
			for (z = 0; z < 256; z++)
			{
				obtrigrams.hash(x, y, z, tri[x][y][z]);
			}
		}
	}
	cout << "done.\n\n";
	cout.flush();

	/*
	 *	Perform actual cryptanalysis using subordinate functions.
	 *	Each returned key is owned here and deleted on every path.
	 */
	cout << "Trying 8-bit XOR decryption . . . ";	// Try XOR decryption
	cout.flush();
	key = xor(cipher, plain, 2.0, false, true, true);
	if (key->Type())
	{
		cout << "EUREKA!\n";
		delete key;
		findings.close();
		return 0;
	}
	delete key;
	cout << "failed.\n";
	cout.flush();

	cout << "Trying shift decryption . . . ";	// Try shift decryption
	cout.flush();
	key = shift(cipher, plain, 2.0, false, true, true);
	if (key->Type())
	{
		cout << "EUREKA!\n";
		delete key;
		findings.close();
		return 0;
	}
	delete key;
	cout << "failed.\n";
	cout.flush();

	cout << "Finding best substitution decryption . . . ";
	cout.flush();	// Try substitution decryption
	key = substitution(cipher, plain, 60.0, false, true, false);
					// If XOR and shift failed, assume substitution
	if (key->Type())
	{
		cout << "EUREKA!\n";
		delete key;
		findings.close();
		return 0;
	}
	delete key;
	cout << "failed.\n";
	cout.flush();

	findings.close();

	return 0;
}

/*
 *	Attempts to break a single-byte XOR cipher.
 *
 *	The first key tried maps the most frequent ciphertext byte onto the
 *	most frequent expected byte (from unigramfreqs).  Further keys are
 *	then tried by walking the remaining ciphertext bytes in order of
 *	decreasing count (ties in ascending byte value), keeping whichever
 *	key scores best.
 *
 *	ct		- ciphertext; ct[0..length] is read (global length)
 *	pt		- scratch buffer that receives candidate plaintext
 *	sec		- time limit in seconds; zero disables the limit
 *	quick	- if set, stop as soon as the score meets THRESHOLD
 *	out		- if set, write the outcome to the findings stream
 *	filter	- if set and the best score is below THRESHOLD, return
 *			  a zero key (new Keys(0)) instead of the best key
 *
 *	Returns a Keys object allocated with new; the caller deletes it.
 *
 *	NOTE(review): `xor` is an alternative token for `^` in ISO C++, so
 *	this function name only compiles with pre-standard compilers or
 *	with alternative tokens disabled.
 */
Keys* xor(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	Chain *data;
	Keys *key;		// Key being considered
	double score;	// Score of key being considered
	unsigned char highest = 0;	// Highest occurring letter
	unsigned char highexpect = 0;	// Letter expected to be highest occurring
	int amount;	// Occurrence rate of highest occurring
	int x;
	clock_t abort = clock() + (clock_t)(sec * CLOCKS_PER_SEC);	// Time out limit
	Keys *nextkey;		// Next key being considered
	double nextscore;	// Score of next key being considered
	unsigned char lesshigh;	// Character occurring less than highest
	unsigned char candidate;	// Character of the table entry being examined
	int lessamount;	// Occurrence rate below amount
	int fails = 0;	// Number of loops without an improvement in score

	/*
	 *	Determine which letter should most commonly occur
	 */
	amount = 0;
	while ( (data = unigramfreqs.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highexpect = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Determine most commonly occurring letter
	 */
	amount = 0;
	while ( (data = obunigrams.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highest = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Try the key that maps the most commonly occurring letter
	 *	to the letter that should most commonly occur
	 */
	key = new Keys(1, (highest ^ highexpect));
	for (x = 0; x <= length; x++)
	{
		pt[x] = ct[x];
	}
	key->decrypt(pt);
	score = grade(pt, key);

	/*
	 *	Check next-most-common letters to see if a better
	 *	key can be found
	 */
	while (true)
	{
		if (sec && (abort < clock()) ) break;		// If timelimit has expired, end
		if (quick && (score >= THRESHOLD) ) break;	// If quick flag is set, and an
													// acceptable score was found, end
		amount = obunigrams.find(highest);
		lessamount = 0;
		lesshigh = highest;

		/*
		 *	Find the letter that follows (amount, highest) in the order
		 *	"count descending, then byte value ascending".  Bytes are
		 *	compared as unsigned char so values above 127 order correctly.
		 */
		while ( (data = obunigrams.dump()) != NULL)
		{
			candidate = (unsigned char)(data->unhashed % 256);

			// Skip letters already tried (at or before the current one)
			if ( !( (data->value < amount)
				|| (data->value == amount && candidate > highest) ) ) continue;

			// Prefer the larger count; on equal counts the smaller byte,
			// so no tied letter is skipped on later passes
			if ( (data->value > lessamount)
				|| (lessamount > 0 && data->value == lessamount && candidate < lesshigh) )
			{
				lesshigh = candidate;
				lessamount = data->value;
			}
		}

		// Every letter that occurs in the ciphertext has been tried
		if (lessamount == 0) break;

		nextkey = new Keys(1, (lesshigh ^ highexpect));
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		nextkey->decrypt(pt);
		nextscore = grade(pt, nextkey);

		highest = lesshigh;
		if (nextscore >= score)
		{
			delete key;
			key = nextkey;
			score = nextscore;
			fails = 0;
		}
		else
		{
			delete nextkey;
			fails++;
			if ((score >= THRESHOLD) && (fails >= 5)) break;
						// End if the last five keys tried didn't give better results
		}
	}
		
	/*
	 *	If best key found didn't meet threshold, discard it; if the
	 *	out flag is set, print score to file
	 */
	if ( (score < THRESHOLD) && (filter) )
	{
		if (out)
		{
			findings << "Failed to decrypt as 8-bit XOR (Score: " << score << ")\n\n";
		}
		delete key;
		Keys *zero = new Keys(0);
		return zero;
	}
	/*
	 *	If best key found met threshold (or filtering is off), return
	 *	it; if the out flag is set, print decryption to file
	 */
	else
	{
		if (out)
		{
			// pt may hold a rejected candidate; redo the best decryption
			for (x = 0; x <= length; x++)
			{
				pt[x] = ct[x];
			}
			key->decrypt(pt);

			findings << "Decrypted as 8-bit XOR!" << endl;
			findings << "Key:   " << (int)key->Ckey() << endl;
			findings << "Score: " << score << endl;
			findings << "Decryption:" << endl << endl;
			findings << pt << endl << endl;
		}
		return key;
	}
}

/*
 *	Attempts to break a single-byte shift (additive, modulo 256) cipher.
 *
 *	The first key tried is the shift that maps the most frequent
 *	ciphertext byte onto the most frequent expected byte (from
 *	unigramfreqs).  Further keys are then tried by walking the remaining
 *	ciphertext bytes in order of decreasing count (ties in ascending
 *	byte value), keeping whichever key scores best.
 *
 *	ct		- ciphertext; ct[0..length] is read (global length)
 *	pt		- scratch buffer that receives candidate plaintext
 *	sec		- time limit in seconds; zero disables the limit
 *	quick	- if set, stop as soon as the score meets THRESHOLD
 *	out		- if set, write the outcome to the findings stream
 *	filter	- if set and the best score is below THRESHOLD, return
 *			  a zero key (new Keys(0)) instead of the best key
 *
 *	Returns a Keys object allocated with new; the caller deletes it.
 */
Keys* shift(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	Chain *data;
	Keys *key;		// Key being considered
	double score;	// Score of key being considered
	unsigned char highest = 0;	// Highest occurring letter
	unsigned char highexpect = 0;	// Letter expected to be highest occurring
	int amount;	// Occurrence rate of highest occurring
	int x;
	clock_t abort = clock() + (clock_t)(sec * CLOCKS_PER_SEC);	// Time out limit
	Keys *nextkey;		// Next key being considered
	double nextscore;	// Score of next key being considered
	unsigned char lesshigh;	// Character occurring less than highest
	unsigned char candidate;	// Character of the table entry being examined
	int lessamount;	// Occurrence rate below amount
	int fails = 0;	// Number of loops without an improvement in score

	/*
	 *	Determine which letter should most commonly occur
	 */
	amount = 0;
	while ( (data = unigramfreqs.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highexpect = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Determine most commonly occurring letter
	 */
	amount = 0;
	while ( (data = obunigrams.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highest = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Try the key that maps the most commonly occurring letter
	 *	to the letter that should most commonly occur
	 */
	x = highest - highexpect;
	if (x < 0) x += 256;	// Keep the shift in 0..255
	key = new Keys(2, x);
	for (x = 0; x <= length; x++)
	{
		pt[x] = ct[x];
	}
	key->decrypt(pt);
	score = grade(pt, key);

	/*
	 *	Check next-most-common letters to see if a better
	 *	key can be found
	 */
	while (true)
	{
		if (sec && (abort < clock()) ) break;		// If timelimit has expired, end
		if (quick && (score >= THRESHOLD) ) break;	// If quick flag is set, and an
													// acceptable score was found, end
		amount = obunigrams.find(highest);
		lessamount = 0;
		lesshigh = highest;

		/*
		 *	Find the letter that follows (amount, highest) in the order
		 *	"count descending, then byte value ascending".  Bytes are
		 *	compared as unsigned char so values above 127 order correctly.
		 */
		while ( (data = obunigrams.dump()) != NULL)
		{
			candidate = (unsigned char)(data->unhashed % 256);

			// Skip letters already tried (at or before the current one)
			if ( !( (data->value < amount)
				|| (data->value == amount && candidate > highest) ) ) continue;

			// Prefer the larger count; on equal counts the smaller byte,
			// so no tied letter is skipped on later passes
			if ( (data->value > lessamount)
				|| (lessamount > 0 && data->value == lessamount && candidate < lesshigh) )
			{
				lesshigh = candidate;
				lessamount = data->value;
			}
		}

		// Every letter that occurs in the ciphertext has been tried
		if (lessamount == 0) break;

		x = lesshigh - highexpect;
		if (x < 0) x += 256;	// Keep the shift in 0..255
		nextkey = new Keys(2, x);
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		nextkey->decrypt(pt);
		nextscore = grade(pt, nextkey);

		highest = lesshigh;
		if (nextscore >= score)
		{
			delete key;
			key = nextkey;
			score = nextscore;
			fails = 0;
		}
		else
		{
			delete nextkey;
			fails++;
			if ((score >= THRESHOLD) && (fails >= 5)) break;
						// End if the last five keys tried didn't give better results
		}
	}
		
	/*
	 *	If best key found didn't meet threshold, discard it; if the
	 *	out flag is set, print score to file
	 */
	if ( (score < THRESHOLD) && (filter) )
	{
		if (out)
		{
			findings << "Failed to decrypt as 8-bit shift (Score: " << score << ")\n\n";
		}
		delete key;
		Keys *zero = new Keys(0);
		return zero;
	}
	/*
	 *	If best key found met threshold (or filtering is off), return
	 *	it; if the out flag is set, print decryption to file
	 */
	else
	{
		if (out)
		{
			// pt may hold a rejected candidate; redo the best decryption
			for (x = 0; x <= length; x++)
			{
				pt[x] = ct[x];
			}
			key->decrypt(pt);

			findings << "Decrypted as shift!" << endl;
			findings << "Key:   " << (int)key->Ckey() << endl;
			findings << "Score: " << score << endl;
			findings << "Decryption:" << endl << endl;
			findings << pt << endl << endl;
		}
		return key;
	}
}

/*
 *	Searches for the best monoalphabetic substitution key by repeated
 *	hill climbing from random starting keys.
 *
 *	Each restart shuffles a 256-entry key map, then repeatedly tries
 *	every pairwise swap of map entries and adopts the best-scoring one
 *	until no swap improves the score.  The best key of each restart is
 *	offered to a second Checker that remembers the best key overall.
 *	Restarts continue until the time limit passes (or, with quick set,
 *	until a restart scores at least THRESHOLD).
 *
 *	ct		- ciphertext; ct[0..length] is read (global length)
 *	pt		- scratch buffer that receives candidate plaintext
 *	sec		- time limit in seconds, checked only between restarts;
 *			  zero disables the limit
 *	quick	- if set, stop after the first restart that meets THRESHOLD
 *	out		- if set, write the outcome to the findings stream
 *	filter	- if set and the best score is below THRESHOLD, return
 *			  a zero key (new Keys(0)) instead of the best key
 *
 *	Returns a Keys object allocated with new; the caller deletes it.
 */
Keys* substitution(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	Keys *key;		// Key being considered
	double score;	// Score of key being considered
	int x;			// Loop variable
	int i, j;		// Letter switch variables
	clock_t abort = clock() + (clock_t)(sec * CLOCKS_PER_SEC);	// Time out limit
	unsigned char map[256];	// Key map
	int nothere[256];	// Whether index char is absent from ciphertext
	Checker *local;			// Compares possible swaps with current key
	Checker *global = new Checker(&bigramlogs);	// Remembers best overall key

	/*
	 *	Fill in nothere[]
	 */
	for (x = 0; x < 256; x++)
	{
		if (obunigrams.find(x)) nothere[x] = false;
		else nothere[x] = true;
	}

	/*
	 *	Use hill-climbing technique with random starting point to
	 *	narrow in on best key
	 */
	do
	{
		cout << "X ";
		cout.flush();

		local = new Checker(&bigramlogs);
		/*
		 *	Choose a random key map
		 */
		for (x = 0; x < 256; x++) map[x] = x;
		for (i = 0; i < 256; i++)
		{
			j = random(0, 255);
			x = map[i];
			map[i] = map[j];
			map[j] = x;
		}

		/*
		 *	Score random key map
		 */
		key = new Keys(3, map);
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		key->decrypt(pt);
		score = local->thorough(pt, key);
		delete key;

		while (true)	// Until a local maximum
		{
			/*
			 *	Try every possible character swap, score them
			 */
			for (i = 0; i < 256; i++)
			{
				if (map[i] == '\0') continue;	// Don't use null
				for (j = i+1; j < 256; j++)
				{
					if (nothere[i] && nothere[j]) continue;
								// Don't swap two nonexistent letters
					if (map[j] == '\0') continue;	// Don't use null

					if (map[i] == map[j])
					{
						cout << "SYSERR: Duplication detected in key:\n" << map;
					}
					
					// Swap, score the resulting key, then swap back
					x = map[i];
					map[i] = map[j];
					map[j] = x;

					key = new Keys(3, map);
					for (x = 0; x <= length; x++)
					{
						pt[x] = ct[x];
					}
					key->decrypt(pt);
					local->thorough(pt, key);
					delete key;

					x = map[i];
					map[i] = map[j];
					map[j] = x;
				}
				if ( (i % 50) == 0)		// Progress indicator
				{
					cout << ". ";
					cout.flush();
				}
			}

			cout << "\n   +" << (local->Best().score - score) << " ";
			cout.flush();

			if (score == local->Best().score) break;	// If no swap was an improvement, end
			score = local->Best().score;
			const unsigned char *keymap = local->Best().key->Skey();
			for (x = 0; x < 256; x++) map[x] = keymap[x];	// Take best swap

			/*
			 *	Restore accurate plaintext to pt
			 */
			key = new Keys(3, map);
			for (x = 0; x <= length; x++)
			{
				pt[x] = ct[x];
			}
			key->decrypt(pt);
			delete key;
		}

		/*
		 *	When no swap is good, remember the key and its score
		 */
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		local->Best().key->decrypt(pt);
		score = global->check(pt, local->Best().key);

		delete local;
		if (quick && (score >= THRESHOLD) ) break;	// If quick flag is set, and an
													// acceptable score was found, end
	} while (!(sec && (clock() > abort)));			// If timelimit has expired, end

	/*
	 *	Remove from the key anything that is a blind guess
	 */
	const unsigned char *keymap = global->Best().key->Skey();
	for (x = 0; x < 256; x++)
	{
		if (!nothere[x]) map[x] = keymap[x];
		else map[x] = 0;
	}
	delete global;	// Best key has been copied into map; keymap must not be used past here

	/*
	 *	Process the best key, whatever it was
	 */
	key = new Keys(3, map);
	for (x = 0; x <= length; x++)
	{
		pt[x] = ct[x];
	}
	key->decrypt(pt);
	score = grade(pt, key);
		
	/*
	 *	If best key found didn't meet threshold, discard it; if the
	 *	out flag is set, print score to file
	 */
	if ( (score < THRESHOLD) && (filter) )
	{
		if (out)
		{
			findings << "Failed to decrypt as substitution (Score: " << score << ")\n\n";
		}
		delete key;
		Keys *zero = new Keys(0);
		return zero;
	}
	/*
	 *	If best key found met threshold (or filtering is off), return
	 *	it; if the out flag is set, print the key table and the
	 *	decryption to file
	 */
	else
	{
		if (out)
		{
			const unsigned char *kmap = key->Skey();

			findings << "Decrypted as substitution!" << endl;
			findings << "Key:   ";
			for (x = 0; x < 256; x++)
			{
				if (!kmap[x]) continue;	// No mapping recorded for this character

				if ((x >= 33) && (x <= 126))	// Printable character
				{
					findings << (char)x;	// Output ciphertext char
				}
				else
				{
					findings << '\0';	// Placeholder byte for unprintable characters
				}

				findings << " (";

				if (x < 100) findings << ' ';	// Add spaces if less than 3 digits
				findings << x;	// Output ASCII
				if (x < 10) findings << ' ';

				findings << ") = ";

				if ((kmap[x] >= 33) && (kmap[x] <= 126))	// Printable character
				{
					findings << (char)(kmap[x]);	// Output Decryption
				}
				else
				{
					findings << '\0';	// Placeholder byte for unprintable characters
				}

				findings << " (";

				if (kmap[x] < 100) findings << ' ';	// Add spaces if less than 3 digits
				findings << (int)(kmap[x]);	// Output ASCII
				if (kmap[x] < 10) findings << ' ';

				findings << ")\n       ";
			}
			findings << endl;
			findings << "Score: " << score << endl;
			findings << "Decryption:" << endl << endl;
			findings << pt << endl << endl;
		}
		return key;
	}
}

// Vigenere cryptanalysis: not yet implemented.  Ignores every argument and
// hands back a freshly allocated zero key, which the caller must delete.
Keys* vigenere(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	return new Keys(0);
}

// Scores a candidate plaintext/key pair with the global Checker `look`.
// When the global `forget` flag is set the Checker's score() is used;
// otherwise check() is used.
double grade(unsigned char *plaintext, Keys *key)
{
	if (!forget)
	{
		return look->check(plaintext, key);
	}

	return look->score(plaintext, key);
}

int complete(unsigned char *map)	// Returns true if map has a decryption for
{									// every letter in obunigrams
	Chain *data;
	while ( (data = obunigrams.dump()) != NULL)
	{
		if (map[data->unhashed] == 0) return false;
	}
	return true;
}

// Exchanges two letters throughout a null-terminated text string, in place:
// every a becomes b and every b becomes a.
void alter(unsigned char a, unsigned char b, unsigned char *text)
{
	unsigned char *cursor = text;

	while (*cursor != '\0')
	{
		if (*cursor == a)
		{
			*cursor = b;
		}
		else if (*cursor == b)
		{
			*cursor = a;
		}
		cursor++;
	}
}

// Extracts bits 0-7 of x, selected with the FIRSTCHAR mask.
unsigned char firstc(unsigned long x)
{
	const unsigned long masked = x & FIRSTCHAR;
	return (unsigned char)masked;
}

// Extracts bits 8-15 of x: selected with the SECONDCHAR mask, then moved
// down into the low byte.
unsigned char secondc(unsigned long x)
{
	const unsigned long masked = x & SECONDCHAR;
	return (unsigned char)(masked >> 8);	// >> 8 is the same as / 256 for unsigned values
}

// Extracts bits 16-23 of x: selected with the THIRDCHAR mask, then moved
// down into the low byte.
unsigned char thirdc(unsigned long x)
{
	const unsigned long masked = x & THIRDCHAR;
	return (unsigned char)(masked >> 16);	// >> 16 is the same as / 65536 for unsigned values
}

// Returns a pseudo-random integer in the inclusive range [low, high],
// drawn from rand().  If the bounds are given in the wrong order they are
// swapped, so the call never divides by zero or yields a value outside
// the range.
int random(int low, int high)
{
	if (high < low)
	{
		int swap = low;
		low = high;
		high = swap;
	}

	return (rand() % (high - low + 1)) + low;
}