/*
 *	Decrypter Demo
 *	By Jeremy Lennert
 *	March 10, 2001
 *	v. 1.0
 *
 *	main.cpp
 *
 *	Repeatedly encrypt and decrypt various texts
 *	to demonstrate cryptanalysis.
 */

// COMPILER DIRECTIVES
#include <assert.h>
#include <fstream.h>
#include <stdlib.h>
#include <time.h>
#include "Keys.h"
#include "FreqTable.h"
#include "Checker.h"

// FUNCTION PROTOTYPES
// One demonstration cycle is encrypt() followed by decrypt(); both return
// zero on success and non-zero on error (see main()).
int encrypt();
int decrypt();

double grade(unsigned char *plaintext, Keys *key);	// Grades with the Checker
													// class, returns result
int complete(unsigned char *map);	// Returns true if map has a decryption for
									// every letter in obunigrams
void alter(unsigned char a, unsigned char b, unsigned char *text);	// Exchange
													// two letters in a text string
unsigned char firstc(unsigned long x);	// Returns lowest 8 bits of x (bits 0-7)
unsigned char secondc(unsigned long x);	// Returns second 8 bits of x (bits 8-15)
unsigned char thirdc(unsigned long x);	// Returns third 8 bits of x (bits 16-23)
int random(int low, int high);	// Returns a random integer between low and high,
								// both inclusive
void sleep(double sec);		// Busy-waits for the specified time, in seconds

/*
 *	- Cryptanalysis Functions -
 *
 *	unsigned char *ct	- pointer to the ciphertext
 *	unsigned char *pt	- pointer to an array that may be overwritten
 *						  to store candidate plaintext
 *	float sec			- length of time (in seconds) after which to stop searching
 *						  for better keys; zero disables the time limit
 *	int quick			- if set, returns first key with a score that meets THRESHOLD
 *	int out				- if set, prints results (the definitions in this file
 *						  write them to cout, not to the "findings" file)
 *	int filter			- if set, and best key fails to meet THRESHOLD, returns a zero key
 *
 *	Both functions return a heap-allocated Keys object.
 *
 *	NOTE(review): "xor" is an alternative operator token in ISO C++, so a
 *	conforming compiler rejects it as a function name unless operator names
 *	are disabled (e.g. GCC's -fno-operator-names) - confirm the build flags.
 */
Keys* xor(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);
Keys* shift(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter);

// GLOBAL VARIABLES
#define MAXLENGTH (5000)		// Maximum length of sample text
#define THRESHOLD (-45000.0)	// Minimum score considered a successful decryption
#define FIRSTCHAR (0x0000FF)	// Mask: binary & keeps bits 0-7 (lowest byte)
#define SECONDCHAR (0x00FF00)	// Mask: binary & keeps bits 8-15 (second byte)
#define THIRDCHAR (0xFF0000)	// Mask: binary & keeps bits 16-23 (third byte)

#define SHOWLENGTH (250)	// Number of characters in text printed to screen

FreqTable unigramfreqs;	// expected unigram frequencies (x10000)
FreqTable bigramfreqs;	// expected bigram frequencies (x10000)
FreqTable bigramlogs;	// natural logs of expected bigram frequencies (x10000)
FreqTable trigramfreqs;	// expected trigram frequencies (x10000)

FreqTable obunigrams;	// observed unigram counts
FreqTable obbigrams;	// observed bigram counts
FreqTable obtrigrams;	// observed trigram counts

Checker *look;		// Checker that examines results (allocated in decrypt())
int length;		// Total length of the text held in cipher[], in bytes
unsigned char cipher[MAXLENGTH+1];	// stores ciphertext (encrypt() also uses it
									// as scratch space for the plaintext it reads)
unsigned char plain[MAXLENGTH+1];	// stores candidate plaintext
ofstream findings;	// Results of attempt to decrypt (opened on "findings.txt"
					// by decrypt())

// Raw tallies gathered while reading the ciphertext; decrypt() copies them
// into the ob* FreqTables above.
int uni[256];	// observed unigram counts
int bi[256][256];	// observed bigram counts
int tri[256][256][256];	// observed trigram counts (256*256*256 ints)

int forget = 0;	// If set, don't let the Checker class compare scores
				// against the best it has seen



/*
 *	Program entry point.
 *
 *	Seeds the pseudo-random number generator from the current time and
 *	then runs the demonstration: encrypt a sample text, try to decrypt
 *	it, pause for ten seconds.  The cycle runs once, or repeats forever
 *	when any command-line argument is supplied.
 *
 *	int argc	- argument count; a value above 1 enables looping
 *	char **argv	- argument vector (not inspected yet, see below)
 *
 *	Always returns 0.
 */
int main(int argc, char **argv)
{
	bool loop = false;

	/*
	 *	Intended switch parsing, not enabled yet:
	 *
	while (--argc > 0) {
		argv++;
		if (argv[0][0] != '-') break;	// not a switch
		switch (argv[0][1]) {
		case 'l':
			loop = true;
			break;
		default:
			cout << "\nIllegal switch: " << argv[0];
			break;
		}
	}
	*/
	loop = (argc > 1);	// XXX: until switches are parsed, any argument means "loop"

	// Seed pseudo-random number generator with time.  time() takes a
	// time_t*, which need not be the same type as long int*, so pass NULL
	// and use the return value instead of a scratch variable.
	srand((unsigned int)time(NULL));

	do
	{
		// "continue" in a do-while jumps straight to the loop test, so a
		// failed step skips the rest of the cycle (including the pause);
		// in single-run mode that ends the program.
		if (encrypt())
		{
			cout << "\nERROR IN ENCRYPTION!\n";
			cout.flush();
			continue;
		}
		if (decrypt())
		{
			cout << "\nERROR IN DECRYPTION!\n";
			cout.flush();
			continue;
		}
		sleep(10.0);
	} while (loop);

	return 0;
}

/*
 *	Encrypt a randomly chosen sample plaintext.
 *
 *	Picks one of plain0.txt .. plain10.txt at random, reads at most
 *	MAXLENGTH bytes of it into the global cipher[] buffer (setting the
 *	global length), and echoes the start of the text to cout.  The text is
 *	then written to "ciphertext.txt", encrypted with a random one-byte key
 *	under either an eight-bit XOR cipher or a shift cipher (picked at
 *	random); the start of the ciphertext is echoed to cout as well.
 *	cipher[] itself still holds the plaintext when this function returns.
 *
 *	Returns 0 on success, or 1 if the plaintext file or the ciphertext
 *	file could not be opened.
 */
int encrypt()
{
	// Sample texts; the array index is the random selection made below.
	static const char *const samples[] =
	{
		"plain0.txt", "plain1.txt", "plain2.txt", "plain3.txt",
		"plain4.txt", "plain5.txt", "plain6.txt", "plain7.txt",
		"plain8.txt", "plain9.txt", "plain10.txt"
	};
	const int sampleCount = (int)(sizeof(samples) / sizeof(samples[0]));

	ifstream plainfile;
	ofstream cipherfile;
	unsigned char thischar;
	int type;	// 1 = XOR cipher, 2 = shift cipher
	unsigned char ckey;
	int pick;	// index of the chosen sample text
	int x;	// loop variable

	pick = random(0, sampleCount - 1);
	if ( (pick < 0) || (pick >= sampleCount) ) pick = 0;	// fall back to plain0.txt
	plainfile.open(samples[pick], ifstream::binary | ifstream::in);
	if (plainfile.fail()) return 1;

	cipherfile.open("ciphertext.txt", ofstream::binary | ofstream::out);
	if (cipherfile.fail())
	{
		plainfile.close();
		return 1;
	}

	type = random(1, 2);
	ckey = rand() % 256;

	/*
	 *	Read the plaintext.  The loop stops at MAXLENGTH characters so
	 *	that the terminator written afterwards still lands inside
	 *	cipher[MAXLENGTH+1].
	 */
	length = 0;
	cout << "\n\n\n\n\nNow encrypting a plaintext that begins:\n\n";
	while (length < MAXLENGTH)
	{
		thischar = plainfile.get();
		if (plainfile.eof() || plainfile.fail()) break;
		cipher[length] = thischar;
		if (length <= SHOWLENGTH)
		{
			if ((length % 50) == 0) cout << "\n>  ";
			cout << thischar;
			if (thischar == '\n') cout << ">  ";
		}
		length++;
	}
	cipher[length] = '\0';
	cout.flush();
	sleep(5.0);

	if (type == 1)	// XOR
	{
		cout << "\n\n\nBy an eight-bit XOR cipher with key: " << (int)ckey << "\n";
	}
	else	// shift
	{
		cout << "\n\nBy a shift cipher with key: " << (int)ckey << "\n";
	}
	cout << "The encrypted ciphertext begins:\n\n";
	cout.flush();

	/*
	 *	Encrypt byte by byte, writing every byte to the ciphertext file
	 *	and echoing the opening of the ciphertext to the screen.
	 */
	for (x = 0; x < length; x++)
	{
		if (type == 1)
		{
			thischar = cipher[x] ^ ckey;
		}
		else
		{
			thischar = (cipher[x] + ckey) % 256;
		}
		cipherfile << thischar;
		if (x <= SHOWLENGTH)
		{
			if ((x % 50) == 0) cout << "\n>  ";
			cout << thischar;
			if (thischar == '\n') cout << ">  ";
		}
	}
	cout.flush();
	sleep(5.0);
	cout << "\n\n\nACTIVATING AUTOMATED DECRYPTION PROGRAM:\n\n";
	cout.flush();

	plainfile.close();
	cipherfile.close();

	return 0;
}

/*
 *	Attempt to break the ciphertext stored in "ciphertext.txt".
 *
 *	Loads the expected unigram, bigram and trigram statistics from
 *	"chars.txt", "digraphs.txt" and "trigraphs.txt", reads the ciphertext
 *	(at most MAXLENGTH bytes, embedded nulls dropped) into the global
 *	cipher[] buffer while tallying its n-grams, and then tries xor()
 *	followed, if that fails, by shift().
 *
 *	Side effects: opens and closes the global findings stream, replaces
 *	the global Checker (look), sets the global length, and fills the
 *	global frequency tables.  The raw tallies uni/bi/tri are zeroed again
 *	once they have been copied into the observed-frequency tables.
 *
 *	Always returns 0; a statistics or ciphertext file that cannot be
 *	opened trips an assert.
 */
int decrypt()
{
	ifstream ciphertext;// ciphertext
	ifstream stats;		// expected n-gram frequencies and their natural logs
	unsigned char lastchar, thischar, nextchar, extra;
	int x, y, z;
	Keys* key;
	bool solved = false;	// set once an attack returns a key of non-zero type
	
	findings.open("findings.txt");

	/*
	 *	Read in unigram frequencies.  Every record is the unigram itself,
	 *	its frequency, the natural log of the frequency, and a newline.
	 *	The fail() tests stop the loop if the stream breaks for any reason
	 *	other than end-of-file, which would otherwise spin forever.
	 */
	cout << "Reading in unigram frequencies . . . ";
	cout.flush();
	stats.open("chars.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(thischar);	// Unigram
			if (stats.eof() || stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.eof() || stats.fail()) break;
		stats >> y;		// Natural log of frequency
			if (stats.eof() || stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		unigramfreqs.hash(thischar, x);
	}
	stats.close();
	stats.clear();	// open() only resets the state flags from C++11 onwards
	cout << "done.\n";
	cout.flush();

	/*
	 *	Read in bigram frequencies
	 */
	cout << "Reading in bigram frequencies . . . ";
	cout.flush();
	stats.open("digraphs.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(thischar);	// First letter of bigram
			if (stats.eof() || stats.fail()) break;
		stats.get(nextchar);	// Next letter of bigram
			if (stats.eof() || stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.eof() || stats.fail()) break;
		stats >> y;		// Natural log of frequency
			if (stats.eof() || stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		bigramfreqs.hash(thischar, nextchar, x);
		bigramlogs.hash(thischar, nextchar, y);
	}
	stats.close();
	stats.clear();
	cout << "done.\n";
	cout.flush();

	// Every call builds a fresh Checker; release the previous one first so
	// that repeated calls do not leak it.  look is a zero-initialised
	// global, so the first delete is a harmless delete of NULL.
	delete look;
	look = new Checker(&bigramlogs);

	/*
	 *	Read in trigram frequencies
	 */
	cout << "Reading in trigram frequencies . . . ";
	cout.flush();
	stats.open("trigraphs.txt", ifstream::binary | ifstream::in);
	assert(!stats.fail());
	while ( !stats.eof() )
	{
		stats.get(lastchar);	// First letter of trigram
			if (stats.eof() || stats.fail()) break;
		stats.get(thischar);	// Next letter of trigram
			if (stats.eof() || stats.fail()) break;
		stats.get(nextchar);	// Last letter of trigram
			if (stats.eof() || stats.fail()) break;
		stats >> x;		// Frequency
			if (stats.eof() || stats.fail()) break;
		stats >> y;		// Natural log of frequency
			if (stats.eof() || stats.fail()) break;
		stats.get(extra);	// Ignore a newline

		trigramfreqs.hash(lastchar, thischar, nextchar, x);
	}
	stats.close();
	cout << "done.\n";
	cout.flush();

	/*
	 *	Read in ciphertext (truncate to MAXLENGTH characters).
	 *	Remove null terminators ('\0') embedded in the file, then
	 *	append one to the end of the string.  Stopping at MAXLENGTH
	 *	keeps that terminator inside cipher[MAXLENGTH+1].
	 *
	 *	Simultaneously count unigrams, bigrams, and trigrams
	 *	in ciphertext; they are put into hash tables below.
	 */
	cout << "Reading in ciphertext . . . ";
	cout.flush();
	ciphertext.open("ciphertext.txt", ifstream::binary | ifstream::in);
	assert(!ciphertext.fail());
	length = 0;
	while (length < MAXLENGTH)
	{
		ciphertext.get(thischar);
		if (ciphertext.eof() || ciphertext.fail()) break;
		if (thischar == '\0') continue;	// Don't store nulls

		cipher[length] = thischar;
		uni[thischar]++;
		if (length >= 1) bi[cipher[length-1]][thischar]++;
		if (length >= 2) tri[cipher[length-2]][cipher[length-1]][thischar]++;
		length++;
	}
	cipher[length] = '\0';
	ciphertext.close();
	cout << "done.\n\n";
	cout.flush();

	/*
	 *	Copy the raw tallies into the observed-frequency tables.  Each
	 *	tally is zeroed as soon as it has been copied, so the next call
	 *	starts counting from zero instead of adding to this ciphertext's
	 *	counts.
	 *	NOTE(review): whether FreqTable::hash() replaces or accumulates an
	 *	existing entry is not visible here - confirm for repeated calls.
	 */
	cout << "Compiling ciphertext heuristic data . . . ";
	cout.flush();
	for (x = 0; x < 256; x++)
	{
		obunigrams.hash(x, uni[x]);
		uni[x] = 0;
		for (y = 0; y < 256; y++)
		{
			// Bigrams take two letters plus a count, the same call shape
			// used for bigramfreqs above.
			obbigrams.hash(x, y, bi[x][y]);
			bi[x][y] = 0;
			for (z = 0; z < 256; z++)
			{
				obtrigrams.hash(x, y, z, tri[x][y][z]);
				tri[x][y][z] = 0;
			}
		}
	}
	cout << "done.\n\n";
	cout.flush();

	/*
	 *	Perform actual cryptanalysis using subordinate functions.
	 *	An attack reports success by returning a key of non-zero type.
	 */
	cout << "Trying 8-bit XOR decryption . . . ";	// Try XOR decryption
	cout.flush();
	key = xor(cipher, plain, 5.0, false, true, true);
	if (key->Type()) solved = true;
	delete key;
	cout.flush();

	if (!solved)
	{
		cout << "Trying shift decryption . . . ";	// Try shift decryption
		cout.flush();
		key = shift(cipher, plain, 5.0, false, true, true);
		if (key->Type()) solved = true;
		delete key;
		cout.flush();
	}

	if (!solved)
	{
		cout << "\nAUTOMATED DECRYPTION UNSUCCESSFUL\n\n";
		cout.flush();
	}

	// Close on every path so the next call can reopen the stream.
	findings.close();

	return 0;
}

/*
 *	Attack the ciphertext as an eight-bit XOR cipher.
 *
 *	Candidate keys are produced by assuming that some ciphertext letter
 *	stands for the letter with the highest expected frequency.  Ciphertext
 *	letters are tried in order of decreasing observed count (equal counts
 *	in order of increasing byte value), and every candidate key is graded
 *	with grade().  The search ends when the time limit expires, when quick
 *	is set and the best score meets THRESHOLD, when five candidates in a
 *	row fail to improve on a score that already meets THRESHOLD, or when
 *	every letter that occurs in the ciphertext has been tried.
 *
 *	unsigned char *ct	- ciphertext; bytes 0..length (global) are read
 *	unsigned char *pt	- scratch buffer of at least length+1 bytes that is
 *						  overwritten with candidate plaintext
 *	float sec			- time limit in seconds; zero disables the limit
 *	int quick			- if set, stop at the first score that meets THRESHOLD
 *	int out				- if set, print the outcome to cout
 *	int filter			- if set, and the best score is below THRESHOLD,
 *						  return a zero key instead of the best key
 *
 *	Returns a heap-allocated Keys object which the caller must delete.
 *
 *	NOTE(review): "xor" is an alternative operator token in ISO C++; a
 *	conforming compiler needs operator names disabled (e.g. GCC's
 *	-fno-operator-names) to accept it as a function name.
 */
Keys* xor(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	Chain *data;
	Keys *key;		// Best key found so far
	double score;	// Score of best key found so far
	unsigned char highest = 0;	// Ciphertext letter tried most recently
	unsigned char highexpect = 0;	// Letter expected to be highest occurring
	unsigned char letter;	// Letter of the table entry being examined
	int amount;	// Occurrence count of the letter tried most recently
	int x;
	clock_t deadline = clock() + (clock_t)(sec * CLOCKS_PER_SEC);	// Time out limit
	Keys *nextkey;		// Next key being considered
	double nextscore;	// Score of next key being considered
	unsigned char lesshigh = 0;	// Next letter to try
	int lessamount;	// Occurrence count of lesshigh
	int fails = 0;	// Number of loops without an improvement in score

	/*
	 *	Determine which letter should most commonly occur.
	 *	(The dump() loops in this function are always run to completion;
	 *	none of them is left early.)
	 */
	amount = 0;
	while ( (data = unigramfreqs.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highexpect = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Determine most commonly occurring letter; among letters with the
	 *	same count the lowest byte value comes first, which is where the
	 *	search order below starts.
	 */
	amount = 0;
	while ( (data = obunigrams.dump()) != NULL)
	{
		letter = (unsigned char)(data->unhashed % 256);
		if ( (data->value > amount)
			|| (data->value == amount && amount > 0 && letter < highest) )
		{
			highest = letter;
			amount = data->value;
		}
	}
	/*
	 *	Try the key that maps the most commonly occurring letter
	 *	to the letter that should most commonly occur
	 */
	key = new Keys(1, (highest ^ highexpect));
	for (x = 0; x <= length; x++)
	{
		pt[x] = ct[x];
	}
	key->decrypt(pt);
	score = grade(pt, key);

	/*
	 *	Check next-most-common letters to see if a better
	 *	key can be found
	 */
	while (true)
	{
		if (sec && (deadline < clock()) ) break;	// If timelimit has expired, end
		if (quick && (score >= THRESHOLD) ) break;	// If quick flag is set, and an
													// acceptable score was found, end
		amount = obunigrams.find(highest);
		lessamount = 0;

		while ( (data = obunigrams.dump()) != NULL)
		{
			// Compare letters as unsigned bytes so that values of 128 and
			// above order correctly.
			letter = (unsigned char)(data->unhashed % 256);

			// Only letters ranked after the one just tried are eligible:
			// a lower count, or the same count and a higher byte value.
			if ( (data->value < amount)
				|| (data->value == amount && letter > highest) )
			{
				// Of those, take the highest count; equal counts go to the
				// lowest byte value so that no tied letter is skipped.
				if ( (data->value > lessamount)
					|| (data->value == lessamount && lessamount > 0 && letter < lesshigh) )
				{
					lesshigh = letter;
					lessamount = data->value;
				}
			}
		}

		// Nothing eligible is left: every letter that occurs has been
		// tried, so stop rather than grading the same key again.
		if (lessamount == 0) break;

		nextkey = new Keys(1, (lesshigh ^ highexpect));
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		nextkey->decrypt(pt);
		nextscore = grade(pt, nextkey);

		highest = lesshigh;
		if (nextscore >= score)
		{
			delete key;
			key = nextkey;
			score = nextscore;
			fails = 0;
		}
		else
		{
			delete nextkey;
			fails++;
			if ((score >= THRESHOLD) && (fails >= 5)) break;
						// End if the last five keys tried didn't give better results
		}
	}
		
	/*
	 *	If best key found didn't meet threshold and the filter flag is
	 *	set, discard it; if the out flag is set, print the score
	 */
	if ( (score < THRESHOLD) && (filter) )
	{
		if (out)
		{
			cout << "failed (Score: " << score << ")\n";
		}
		delete key;
		return new Keys(0);
	}

	/*
	 *	Otherwise return the best key found; if the out flag is set,
	 *	print the start of the decryption
	 */
	if (out)
	{
		// pt holds the most recent candidate, so decrypt again with the
		// best key before printing.
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		key->decrypt(pt);

		cout << "EUREKA!\n\n";
		cout << "Decrypted as eight-bit XOR!\n";
		cout << "Key:   " << (int)key->Ckey() << endl;
		cout << "Score: " << score << endl;
		cout << "Decryption begins:" << endl << endl;

		// Stop before index length: that slot is the terminator, not text.
		for (x = 0; (x <= SHOWLENGTH) && (x < length); x++)
		{
			if ((x % 50) == 0) cout << "\n>  ";
			cout << pt[x];
			if (pt[x] == '\n') cout << ">  ";
		}

		cout << "\n\n";
	}
	return key;
}

/*
 *	Attack the ciphertext as a shift cipher (each byte offset by a fixed
 *	amount, modulo 256).
 *
 *	Candidate keys are produced by assuming that some ciphertext letter
 *	stands for the letter with the highest expected frequency.  Ciphertext
 *	letters are tried in order of decreasing observed count (equal counts
 *	in order of increasing byte value), and every candidate key is graded
 *	with grade().  The search ends when the time limit expires, when quick
 *	is set and the best score meets THRESHOLD, when five candidates in a
 *	row fail to improve on a score that already meets THRESHOLD, or when
 *	every letter that occurs in the ciphertext has been tried.
 *
 *	unsigned char *ct	- ciphertext; bytes 0..length (global) are read
 *	unsigned char *pt	- scratch buffer of at least length+1 bytes that is
 *						  overwritten with candidate plaintext
 *	float sec			- time limit in seconds; zero disables the limit
 *	int quick			- if set, stop at the first score that meets THRESHOLD
 *	int out				- if set, print the outcome to cout
 *	int filter			- if set, and the best score is below THRESHOLD,
 *						  return a zero key instead of the best key
 *
 *	Returns a heap-allocated Keys object which the caller must delete.
 */
Keys* shift(unsigned char *ct, unsigned char *pt, float sec, int quick, int out, int filter)
{
	Chain *data;
	Keys *key;		// Best key found so far
	double score;	// Score of best key found so far
	unsigned char highest = 0;	// Ciphertext letter tried most recently
	unsigned char highexpect = 0;	// Letter expected to be highest occurring
	unsigned char letter;	// Letter of the table entry being examined
	int amount;	// Occurrence count of the letter tried most recently
	int offset;	// Candidate shift amount, 0..255
	int x;
	clock_t deadline = clock() + (clock_t)(sec * CLOCKS_PER_SEC);	// Time out limit
	Keys *nextkey;		// Next key being considered
	double nextscore;	// Score of next key being considered
	unsigned char lesshigh = 0;	// Next letter to try
	int lessamount;	// Occurrence count of lesshigh
	int fails = 0;	// Number of loops without an improvement in score

	/*
	 *	Determine which letter should most commonly occur.
	 *	(The dump() loops in this function are always run to completion;
	 *	none of them is left early.)
	 */
	amount = 0;
	while ( (data = unigramfreqs.dump()) != NULL)
	{
		if (data->value > amount)
		{
			highexpect = (unsigned char)(data->unhashed % 256);
			amount = data->value;
		}
	}
	/*
	 *	Determine most commonly occurring letter; among letters with the
	 *	same count the lowest byte value comes first, which is where the
	 *	search order below starts.
	 */
	amount = 0;
	while ( (data = obunigrams.dump()) != NULL)
	{
		letter = (unsigned char)(data->unhashed % 256);
		if ( (data->value > amount)
			|| (data->value == amount && amount > 0 && letter < highest) )
		{
			highest = letter;
			amount = data->value;
		}
	}
	/*
	 *	Try the key that maps the most commonly occurring letter
	 *	to the letter that should most commonly occur
	 */
	offset = highest - highexpect;
	if (offset < 0) offset += 256;	// wrap into 0..255
	key = new Keys(2, offset);
	for (x = 0; x <= length; x++)
	{
		pt[x] = ct[x];
	}
	key->decrypt(pt);
	score = grade(pt, key);

	/*
	 *	Check next-most-common letters to see if a better
	 *	key can be found
	 */
	while (true)
	{
		if (sec && (deadline < clock()) ) break;	// If timelimit has expired, end
		if (quick && (score >= THRESHOLD) ) break;	// If quick flag is set, and an
													// acceptable score was found, end
		amount = obunigrams.find(highest);
		lessamount = 0;

		while ( (data = obunigrams.dump()) != NULL)
		{
			// Compare letters as unsigned bytes so that values of 128 and
			// above order correctly.
			letter = (unsigned char)(data->unhashed % 256);

			// Only letters ranked after the one just tried are eligible:
			// a lower count, or the same count and a higher byte value.
			if ( (data->value < amount)
				|| (data->value == amount && letter > highest) )
			{
				// Of those, take the highest count; equal counts go to the
				// lowest byte value so that no tied letter is skipped.
				if ( (data->value > lessamount)
					|| (data->value == lessamount && lessamount > 0 && letter < lesshigh) )
				{
					lesshigh = letter;
					lessamount = data->value;
				}
			}
		}

		// Nothing eligible is left: every letter that occurs has been
		// tried, so stop rather than grading the same key again.
		if (lessamount == 0) break;

		offset = lesshigh - highexpect;
		if (offset < 0) offset += 256;	// wrap into 0..255
		nextkey = new Keys(2, offset);
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		nextkey->decrypt(pt);
		nextscore = grade(pt, nextkey);

		highest = lesshigh;
		if (nextscore >= score)
		{
			delete key;
			key = nextkey;
			score = nextscore;
			fails = 0;
		}
		else
		{
			delete nextkey;
			fails++;
			if ((score >= THRESHOLD) && (fails >= 5)) break;
						// End if the last five keys tried didn't give better results
		}
	}
		
	/*
	 *	If best key found didn't meet threshold and the filter flag is
	 *	set, discard it; if the out flag is set, print the score
	 */
	if ( (score < THRESHOLD) && (filter) )
	{
		if (out)
		{
			cout << "failed (Score: " << score << ")\n";
		}
		delete key;
		return new Keys(0);
	}

	/*
	 *	Otherwise return the best key found; if the out flag is set,
	 *	print the start of the decryption
	 */
	if (out)
	{
		// pt holds the most recent candidate, so decrypt again with the
		// best key before printing.
		for (x = 0; x <= length; x++)
		{
			pt[x] = ct[x];
		}
		key->decrypt(pt);

		cout << "EUREKA!\n\n";
		cout << "Decrypted as shift!\n";
		cout << "Key:   " << (int)key->Ckey() << endl;
		cout << "Score: " << score << endl;
		cout << "Decryption begins:" << endl << endl;

		// Stop before index length: that slot is the terminator, not text.
		for (x = 0; (x <= SHOWLENGTH) && (x < length); x++)
		{
			if ((x % 50) == 0) cout << "\n>  ";
			cout << pt[x];
			if (pt[x] == '\n') cout << ">  ";
		}

		cout << "\n\n";
	}
	return key;
}

/*
 *	Grade a candidate plaintext with the global Checker and return the
 *	result.  The global forget flag picks the Checker call: score() when
 *	it is set, check() when it is clear.
 */
double grade(unsigned char *plaintext, Keys *key)
{
	if (!forget)
	{
		return look->check(plaintext, key);
	}

	return look->score(plaintext, key);
}

int complete(unsigned char *map)	// Returns true if map has a decryption for
{									// every letter in obunigrams
	Chain *data;
	while ( (data = obunigrams.dump()) != NULL)
	{
		if (map[data->unhashed] == 0) return false;
	}
	return true;
}

/*
 *	Exchange two letters throughout a null-terminated text string:
 *	every a becomes b and every b becomes a, in place.
 */
void alter(unsigned char a, unsigned char b, unsigned char *text)
{
	unsigned char *cursor = text;

	while (*cursor != '\0')
	{
		const unsigned char current = *cursor;

		if (current == a)
		{
			*cursor = b;
		}
		else if (current == b)
		{
			*cursor = a;
		}

		++cursor;
	}
}

// Returns the lowest 8 bits of x (bits 0-7), selected with the FIRSTCHAR mask
unsigned char firstc(unsigned long x)
{
	const unsigned long masked = x & FIRSTCHAR;
	return (unsigned char)masked;
}

// Returns the second 8 bits of x (bits 8-15), selected with the SECONDCHAR mask
unsigned char secondc(unsigned long x)
{
	const unsigned long masked = x & SECONDCHAR;
	return (unsigned char)(masked >> 8);	// same as dividing by 256
}

// Returns the third 8 bits of x (bits 16-23), selected with the THIRDCHAR mask
unsigned char thirdc(unsigned long x)
{
	const unsigned long masked = x & THIRDCHAR;
	return (unsigned char)(masked >> 16);	// same as dividing by 65536 (256 * 256)
}

/*
 *	Returns a pseudo-random integer between low and high, both inclusive,
 *	drawn from rand().
 *
 *	The bounds may be given in either order; they are swapped if needed,
 *	so the span handed to the modulo is never zero or negative.
 *	The result is only as uniform as rand() % span allows, and the span
 *	itself must fit in an int.
 */
int random(int low, int high)
{
	if (high < low)
	{
		const int swapped = low;
		low = high;
		high = swapped;
	}

	const int span = high - low + 1;	// number of distinct results, at least 1
	return (rand() % span) + low;
}

/*
 *	Waits for the specified time, in seconds, by polling clock() in a
 *	busy loop until the target tick count has been reached.
 */
void sleep(double sec)
{
	const clock_t target = clock() + (clock_t)(sec * CLOCKS_PER_SEC);

	for (;;)
	{
		if (clock() >= target)
		{
			return;
		}
	}
}