#include <stdexcept>
#include <vector>
#include <string>
#include <climits>  // INT_MIN, INT_MAX
#include <pcap.h>
#include <ctype.h>  // isxdigit
#include <unistd.h> // getopt
#include <stdlib.h> // atoi,rand
#include <libgen.h> // basename

#include "Generator.h"
#include "util.h"

// Pulls the WEP IV and the first keystream byte out of a raw captured frame.
// Returns false (leaving `iv` untouched) when the frame is too short to hold
// byte 28, the first encrypted byte; true otherwise.
bool Generator::extractIV(const vector<byte>& packet, WeakIV& iv) {
	// bytes 24..26 carry the IV, byte 28 is the first encrypted byte
	if (packet.size() <= 28)
		return false;

	// copy the IV bytes in the order they appear in the frame
	iv.IV.clear();
	for (int off = 24; off <= 26; ++off)
		iv.IV.push_back(packet[off]);

	// Keystream byte = ciphertext ^ plaintext, assuming the first plaintext
	// byte is 0xaa (presumably the start of an LLC/SNAP header).
	iv.Z = packet[28] ^ 0xaa;

	dp.printDebug("extractIV: IV=%s Z=%02x packet=%s\n", 
			dp.printVector("", iv.IV).c_str(), iv.Z, 
			dp.printVector("", packet).c_str());

	return true;
}

// Builds the BPF expression used by CaptureGenerator: protected 802.11 data
// frames, optionally restricted to one MAC address and/or one WEP key id.
// `mac` is NULL or an already validated "xx:xx:xx:xx:xx:xx" string; `keyid`
// is 0..3, or negative for "any key id".
static std::string buildCaptureFilter(const char* mac, int keyid) {
	std::string text;

	if (mac) {
		text += "wlan host ";
		text += mac;
		text += " and ";
	}

	// frame control: data frame (0x08) with the Protected flag (0x40) set
	text += "wlan[0:2] & 0xff40 = 0x0840";

	if (keyid >= 0) {
		// key id is kept in the top two bits of byte 27 (0x00/0x40/0x80/0xc0)
		text += " and wlan[27] = 0x";
		text += "048c"[keyid & 3];
		text += '0';
	}

	return text;
}

// Opens a live interface (live == true) or a capture file (live == false)
// named by `where`, and installs a filter for WEP-protected data frames.
// Throws BadMACAddress / BadKeyID for invalid arguments, BadCaptureSource if
// the source cannot be opened, BadDataLinkType if it is not raw 802.11, and
// std::runtime_error if the filter cannot be compiled or installed.
CaptureGenerator::CaptureGenerator(const char* where, const char* mac, int keyid, bool live, int debug_level) {

	dp.setPrefix("[cap_gen] ");
	dp.setLevel(debug_level);

	assert(where);

	// Validate arguments before opening anything: when a constructor throws,
	// the destructor does not run, so a handle opened first would leak.
	if (mac && !checkMAC(mac))
		throw BadMACAddress(mac);

	if (keyid > 3)
		throw BadKeyID(keyid);

	char errbuf[PCAP_ERRBUF_SIZE];

	if (live)
		handle = pcap_open_live(where, BUFSIZ, 1, 0, errbuf);
	else
		handle = pcap_open_offline(where, errbuf);

	if (!handle)
		throw BadCaptureSource(errbuf);

	int type = pcap_datalink(handle);
	if (type != DLT_IEEE802_11) {
		// libpcap returns NULL for link types it has no description for
		const char* desc = pcap_datalink_val_to_description(type);
		pcap_close(handle);
		throw BadDataLinkType(desc ? desc : "unknown");
	}

	std::string filter_text = buildCaptureFilter(mac, keyid);
	dp.printVerbose("using filter: %s\n", filter_text.c_str());

	// Checked explicitly rather than inside assert(): with NDEBUG an assert's
	// expression is not evaluated, which would skip installing the filter.
	struct bpf_program filter;
	if (pcap_compile(handle, &filter, &filter_text[0], 1, 0) == -1) {
		std::string err = pcap_geterr(handle); // copy before the handle goes away
		pcap_close(handle);
		throw std::runtime_error("cannot compile capture filter: " + err);
	}

	int rc = pcap_setfilter(handle, &filter);
	pcap_freecode(&filter); // no longer needed once pcap_setfilter() was called
	if (rc == -1) {
		std::string err = pcap_geterr(handle);
		pcap_close(handle);
		throw std::runtime_error("cannot set capture filter: " + err);
	}
}

// Reads captured frames until one is long enough to carry an IV and the first
// encrypted byte, then fills `iv` with its IV (bytes 24..26) and Z, the first
// keystream byte, assuming the first plaintext byte is 0xaa.
// Returns false only when pcap_next() yields no more packets.
bool CaptureGenerator::getWeakIV(WeakIV& iv) {
	struct pcap_pkthdr pcap_hdr;
	const u_char* data;

	// Skip frames that are too short instead of returning false: callers
	// loop on this method until it returns false, so reporting one truncated
	// frame as failure would end the whole capture early.
	for (;;) {
		data = pcap_next(handle, &pcap_hdr);
		if (!data)
			return false;

		if (pcap_hdr.caplen >= 29)
			break;

		if (dp.getLevel()>3)
			dp.printDebug("getWeakIV: skipping short packet (caplen=%u)\n",
				(unsigned)pcap_hdr.caplen);
	}

	// IV bytes, in the order they appear in the frame
	iv.IV.clear();
	iv.IV.push_back(data[24]);
	iv.IV.push_back(data[25]);
	iv.IV.push_back(data[26]);
	// assume the first byte of plaintext is 0xaa
	iv.Z = data[28] ^ 0xaa;

	if (dp.getLevel()>3) {
		vector<byte> packet(data, data + pcap_hdr.caplen);

		dp.printDebug("getWeakIV: IV=%s Z=%02x packet=%s\n", 
			dp.printVector("", iv.IV).c_str(), iv.Z,
			dp.printVector("", packet).c_str());
	}
	
	return true;
}

// Copies the next captured frame (its captured length, caplen) into `packet`.
// Returns false when pcap_next() has nothing more to deliver.
bool CaptureGenerator::getPacket(vector<byte>& packet) {
	struct pcap_pkthdr hdr;
	const u_char* const raw = pcap_next(handle, &hdr);

	if (raw == NULL)
		return false;

	packet.assign(raw, raw + hdr.caplen);

	dp.printDebug("getPacket: packet=%s\n",
			dp.printVector("", packet).c_str());
	return true;
}

// Releases the pcap handle opened by the constructor.
// NOTE(review): copying a CaptureGenerator would close the same handle twice
// unless the class disables copying — confirm in Generator.h.
CaptureGenerator::~CaptureGenerator()
{
	// The constructor throws rather than leave `handle` NULL, so a fully
	// constructed object always holds an open handle here.
	assert(handle != NULL);
	pcap_close(handle);
}

// Returns true iff `mac` has the exact form xx:xx:xx:xx:xx:xx, where every x
// is a hex digit (either case). A NULL pointer is rejected.
bool CaptureGenerator::checkMAC(const char* mac) {
	if (!mac || strlen(mac) != 12+5)
		return false;

	for (int i = 0; i < 12+5; i++) {
		// Positions 2, 5, 8, 11 and 14 accept only ':'; every other position
		// accepts only a hex digit. A hex digit in a separator slot is an error.
		const bool separator = (i + 1) % 3 == 0;
		// isxdigit() needs a value representable as unsigned char
		const unsigned char c = static_cast<unsigned char>(mac[i]);

		if (separator ? c != ':' : !isxdigit(c))
			return false;
	}

	return true;
}

// Creates a generator with a fresh random secret key of `len` (> 0) bytes,
// using 3-byte IVs and a 256-entry RC4 state.
RandomGenerator::RandomGenerator(int len, int debug_level)
	: I(3), l(len), N(256)
{
	counter = 0;

	dp.setPrefix("[rand_gen] ");
	dp.setLevel(debug_level);

	assert(len > 0);

	// one rand() draw per key byte; the conversion to byte keeps the low bits
	key.resize(len);
	for (int pos = 0; pos < len; ++pos)
		key[pos] = rand();

	dp.printVerbose("generated key: %s\n", dp.printVector("", key, ":").c_str());
}

// Creates a generator that uses the caller-supplied secret key `k`, with
// 3-byte IVs and a 256-entry RC4 state.
RandomGenerator::RandomGenerator(const vector<byte>& k, int debug_level)
	: I(3), l(k.size()), N(256), key(k)
{
	counter = 0;

	dp.setPrefix("[rand_gen] ");
	dp.setLevel(debug_level);

	dp.printVerbose("using key: %s\n", dp.printVector("", key, ":").c_str());
}

// Produces a random I-byte IV and its first keystream byte Z, computed by
// genByte1() from IV || key. Never runs out, so it always returns true.
bool RandomGenerator::getWeakIV(WeakIV& iv) {
	iv.IV.clear();
	for (int n = 0; n < I; ++n)
		iv.IV.push_back(rand());

	// no ciphertext involved: Z comes straight from the RC4 key schedule
	iv.Z = genByte1(iv.IV);

	if (dp.getLevel() <= 3)
		return true;

	dp.printDebug("getWeakIV: %s Z=%02x\n", 
			dp.printVector("IV=", iv.IV).c_str(), iv.Z);
	return true;
}

// Returns the first RC4 keystream byte for the RC4 key IV || key.
// `iv` is taken by value, so appending the secret key does not affect the caller.
byte RandomGenerator::genByte1(vector<byte> iv) {

	iv.insert(iv.end(), key.begin(), key.end());

	// A vector rather than `byte S[N]`: N is not a constant expression, so the
	// array would be a variable-length array, which ISO C++ does not allow.
	vector<byte> S(N);
	int i, j;

	// --- KSA start ---
	for (i = 0; i < N; i++)
		S[i] = i;
	j = 0;

	for (i = 0; i < N; i++) {
		j = (j + S[i] + iv[i % iv.size()]) % N;
		swap(S[i], S[j]);
	}
	// --- KSA end ---

	// First PRGA step: i = 1, j = 0 + S[1]. The swap must happen before the
	// output lookup. Otherwise the result differs from real RC4 (and from
	// encrypt()) whenever the lookup index is 1 or j.
	i = 1;
	j = S[i];
	swap(S[i], S[j]);
	return S[(S[i] + S[j]) % N];
}

// Builds the next simulated WEP frame into `packet` and returns true.
// Frames are built from fixed plaintext templates (802.11 header, IV/keyid
// placeholder, payload starting with 0xaa 0xaa 0x03, ICV placeholder). Each
// frame gets a CRC ICV and a fresh random IV, and is then encrypted in place.
bool RandomGenerator::getPacket(vector<byte>& packet) {
	static byte arp_req[] = {
		0x08, 0x41, 0x75, 0x00, 0x00, 0x40, 0x96, 0x53,
		0x11, 0xcc, 0x00, 0x0d, 0x28, 0x4d, 0xcb, 0xf1,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x50, 0x69,
		0xff, 0xff, 0xff, 0x00, // IV + keyid
		0xaa, 0xaa, 0x03, 0x00, 0x00, 0x00, 0x08, 0x06,
		0x00, 0x01, 0x08, 0x00, 0x06, 0x04, 0x00, 0x01,
		0x00, 0x0d, 0x28, 0x4d, 0xcb, 0xf1, 0xa9, 0xfe,
		0x7f, 0xb2, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0xa9, 0xfe, 0x7f, 0x01, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0xff, 0xff, 0xff, 0xff  // ICV
	};

	static byte arp_resp[] = {
		0x08, 0x42, 0x75, 0x00, 0x00, 0x0d, 0x28, 0x4d,
		0xcb, 0xf1, 0x00, 0x40, 0x96, 0x53, 0x11, 0xcc,
		0x00, 0x90, 0x96, 0x21, 0xe5, 0x17, 0xc0, 0x71,
		0xff, 0xff, 0xff, 0x00, // IV + keyid
		0xaa, 0xaa, 0x03, 0x00, 0x00, 0x00, 0x08, 0x06,
		0x00, 0x01, 0x08, 0x00, 0x06, 0x04, 0x00, 0x02,
		0x00, 0x90, 0x96, 0x21, 0xe5, 0x17, 0xa9, 0xfe,
		0x7f, 0x01, 0x00, 0x0d, 0x28, 0x4d, 0xcb, 0xf1,
		0xa9, 0xfe, 0x7f, 0xb2, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		0xff, 0xff, 0xff, 0xff  // ICV
	};

	static byte ping_req[] = {
		0x08, 0x41, 0x75, 0x00, 0x00, 0x40, 0x96, 0x53,
		0x11, 0xcc, 0x00, 0x0d, 0x28, 0x4d, 0xcb, 0xf1,
		0x00, 0x90, 0x96, 0x21, 0xe5, 0x17, 0x60, 0x69,
		0xff, 0xff, 0xff, 0x00, // IV + keyid
		0xaa, 0xaa, 0x03, 0x00, 0x00, 0x00, 0x08, 0x00,
		0x45, 0x00, 0x00, 0x3c, 0x03, 0x20, 0x00, 0x00,
		0x80, 0x01, 0xe4, 0xf0, 0xa9, 0xfe, 0x7f, 0xb2,
		0xa9, 0xfe, 0x7f, 0x01, 0x08, 0x00, 0x37, 0x5c,
		0x02, 0x00, 0x14, 0x00, 0x61, 0x62, 0x63, 0x64,
		0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c,
		0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74,
		0x75, 0x76, 0x77, 0x61, 0x62, 0x63, 0x64, 0x65,
		0x66, 0x67, 0x68, 0x69,
		0xff, 0xff, 0xff, 0xff  // ICV
	};

	static byte ping_resp[] = {
		0x08, 0x42, 0x75, 0x00, 0x00, 0x0d, 0x28, 0x4d,
		0xcb, 0xf1, 0x00, 0x40, 0x96, 0x53, 0x11, 0xcc,
		0x00, 0x90, 0x96, 0x21, 0xe5, 0x17, 0xd0, 0x71,
		0xff, 0xff, 0xff, 0x00, // IV + keyid
		0xaa, 0xaa, 0x03, 0x00, 0x00, 0x00, 0x08, 0x00,
		0x45, 0x00, 0x00, 0x3c, 0xa9, 0xc3, 0x00, 0x00,
		0x40, 0x01, 0x7e, 0x4d, 0xa9, 0xfe, 0x7f, 0x01,
		0xa9, 0xfe, 0x7f, 0xb2, 0x00, 0x00, 0x3f, 0x5c,
		0x02, 0x00, 0x14, 0x00, 0x61, 0x62, 0x63, 0x64,
		0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c,
		0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74,
		0x75, 0x76, 0x77, 0x61, 0x62, 0x63, 0x64, 0x65,
		0x66, 0x67, 0x68, 0x69,
		0xff, 0xff, 0xff, 0xff  // ICV
	};

	// Choose the template: slots 0 and 1 of every run of ten are the ARP
	// exchange; the other slots alternate ping request (even) / reply (odd).
	byte* tmpl;
	int tmpl_len;
	switch (counter % 10) {
		case 0:
			tmpl = arp_req;
			tmpl_len = sizeof(arp_req);
			break;
		case 1:
			tmpl = arp_resp;
			tmpl_len = sizeof(arp_resp);
			break;
		default:
			if (counter % 2 != 0) {
				tmpl = ping_resp;
				tmpl_len = sizeof(ping_resp);
			} else {
				tmpl = ping_req;
				tmpl_len = sizeof(ping_req);
			}
			break;
	}

	packet.assign(tmpl, tmpl + tmpl_len);

	// ICV: CRC over the plaintext from offset 28 up to (not including) the
	// trailing 4-byte ICV slot, taken from the untouched template.
	CRC32 crc;
	crc.Update(reinterpret_cast<char*>(tmpl) + 28, tmpl_len - 4 - 28);
	crc.Finish();

	// Byte order is whatever CRC32::operator[] yields; the original note says
	// this stores the CRC bytes in reverse order — not verified here.
	const int icv_at = tmpl_len - 4;
	for (int b = 0; b < 4; ++b)
		packet[icv_at + b] = crc[b];

	// fresh random IV in bytes 24 .. 24+I-1
	for (int n = 0; n < I; ++n)
		packet[24 + n] = rand();

	if (dp.getLevel()>3)
	  dp.printDebug("getPacket: before encryption: %s\n",
		dp.printVector("", packet, "").c_str());

	encrypt(packet);

	if (dp.getLevel()>3)
	  dp.printDebug("getPacket: after encryption: %s\n",
		dp.printVector("", packet, "").c_str());

	++counter;
	return true;
}

// RC4-encrypts `packet` in place from byte 28 to the end. The RC4 key is the
// I IV bytes stored at offset 24 followed by the secret key.
// Assumes packet.size() >= 24 + I (true for the frames built by getPacket()).
void RandomGenerator::encrypt(vector<byte>& packet) {
	// construct rc4 key: IV || secret key
	vector<byte> rc4_key(packet.begin() + 24, packet.begin() + 24 + I);
	rc4_key.insert(rc4_key.end(), key.begin(), key.end());

	// --- KSA start ---
	// A vector rather than `byte S[N]`: N is not a constant expression, so the
	// array would be a variable-length array, which ISO C++ does not allow.
	vector<byte> S(N);
	int i, j;

	for (i = 0; i < N; i++)
		S[i] = i;
	j = 0;

	for (i = 0; i < N; i++) {
		j = (j + S[i] + rc4_key[i % rc4_key.size()]) % N;
		swap(S[i], S[j]);

		if (dp.getLevel()>4)
		  dp.printDebug2("encrypt: KSA: i=%d j=%d K=%s S=[%s]\n",i,j,
			dp.printVector("", rc4_key).c_str(),
			dp.printVector("", &S[0], N, " ").c_str());
	}
	// --- KSA end ---

	i = j = 0;
	// PRGA + encrypt
	for (vector<byte>::size_type k = 28; k < packet.size(); k++) {
		i = (i + 1) % N;
		j = (j + S[i]) % N;
		swap(S[i], S[j]);

		const byte ks = S[(S[i] + S[j]) % N];
		packet[k] = packet[k] ^ ks;

		if (dp.getLevel()>4)
		  dp.printDebug2("encrypt: PRGA: i=%d j=%d K=%s S=[%s] Z=%02x\n",
			i,j, dp.printVector("", rc4_key).c_str(),
			dp.printVector("", &S[0], N, " ").c_str(),
			ks);
	}
}

#ifdef TEST
// Prints command-line help to stderr. `progname` is typically argv[0].
void usage(char* progname) {
	// basename() may modify its argument and may return static storage,
	// so call it once and reuse the result for both %s conversions.
	const char* name = basename(progname);

	// Adjacent string literals: a raw line break inside a literal is ill-formed C++.
	fprintf(stderr,
		"\n"
		"Print IVs from frames destined to/from the specified MAC address and encrypted with the specified keyID.\n"
		"\n"
		"Usage: %s -i ifname -m mac -k keyid\n"
		"       %s -f fname  -m mac -k keyid\n"
		"\n"
		"Options:\n"
		"         -i ifname  Capture frames from interface ifname\n"
		"         -f fname   Read frames from tcpdump formatted file fname\n"
		"         -m mac     Use only frames destined to/from this MAC address\n"
		"         -k keyid   Use only frames encrypted with this keyid\n",
		name, name);
}

// Maximum argument lengths, excluding the terminating NUL.
#define IF_LEN 256
#define FN_LEN 256
#define MAC_LEN (12+5)   /* "xx:xx:xx:xx:xx:xx": 12 hex digits + 5 colons */

// exit status for invalid command-line arguments
#define BAD_ARG 1

int main(int argc, char* argv[]) {
	int option;

	char ifname[IF_LEN+1];	bool ifname_set = false;
	char fname[FN_LEN+1];	bool fname_set = false;
	char mac[MAC_LEN+1];	bool mac_set = false;
	int keyid;		bool keyid_set = false;

	// parse command line arguments
	while ((option = getopt(argc, argv, "i:f:m:k:h")) != -1) 
		switch(option) {
			case 'i': // interface name
				assert(optarg);
				if (strlen(optarg)>IF_LEN) {
					fprintf(stderr, "Error: interface name too long: %s\n", optarg);
					exit(BAD_ARG);
				}
				strncpy(ifname, optarg, IF_LEN+1);
				ifname_set = true;
				break;
				
			case 'f': // file name
				assert(optarg);
				if (strlen(optarg)>FN_LEN) {
					fprintf(stderr, "Error: file name too long: %s\n", optarg);
					exit(BAD_ARG);
				}
				strncpy(fname, optarg, FN_LEN+1);
				fname_set = true;
				break;
				
			case 'm': // mac address
				assert(optarg);
				if (strlen(optarg)!=MAC_LEN) {
					fprintf(stderr, "Error: bad mac address lenght: %s\n", optarg);
					exit(BAD_ARG);
				}
				strncpy(mac, optarg, MAC_LEN+1);
				mac_set = true;
				break;

			case 'k': // key id
				assert(optarg);
				keyid = strtol(optarg, NULL, 10);
				if (keyid == LONG_MIN || keyid == LONG_MAX) {
					fprintf(stderr, "Error: bad keyid value: %s\n", optarg);
					exit(BAD_ARG);
				}
				keyid_set = true;
				break;

			case 'h': // print help 
			default :
				usage(argv[0]); 
				exit(BAD_ARG);
		}

	// check if we have all the options we need
	if (!ifname_set && !fname_set) {
		fprintf(stderr, "Error: either interface name or file name need to be set\n");
		usage(argv[0]);
		exit(BAD_ARG);
	}
	
	if (ifname_set && fname_set) {
		fprintf(stderr, "Error: both interface and file name set\n");
		usage(argv[0]);
		exit(BAD_ARG);
	}

	if (!mac_set) {
		fprintf(stderr, "Error: mac address not set\n");
		usage(argv[0]);
		exit(BAD_ARG);
	}

	if (!keyid_set) {
		fprintf(stderr, "Error: keyid not set\n");
		usage(argv[0]);
		exit(BAD_ARG);
	}
	

	CaptureGenerator* cg;
	
	try {
		cg = new CaptureGenerator( (ifname_set) ? ifname:fname, mac, keyid, ifname_set);
	} catch (runtime_error& e) {
		printf("%s\n", e.what());
		exit(BAD_ARG);
	}
	
	WeakIV iv;
	while(cg->getWeakIV(iv)) {
		printf("IV: ");
		for(int i=0; i<iv.IV.size(); i++)
			printf("%02x", iv.IV[i]);
		printf(" Z: %02x\n", iv.Z);
	}
}
#endif
