// Copyright (C) FIT VUT
// Petr Lampa <lampa@fit.vutbr.cz>
// $Id$
// vi:set ts=8 sts=8 sw=8:
//
// UDP protocol
//

#ifndef __NDWATCH_UDP_H__
#define __NDWATCH_UDP_H__

#include "ipv4.h"
#include "ipv6.h"

// +--------+--------+--------+--------+
// |     Source      |   Destination   |
// |      Port       |      Port       |
// +--------+--------+--------+--------+
// |     Length      |    Checksum     |
// +--------+--------+--------+--------+
// |          data octets ...
// +---------------- ...
// RFC 768
class UDP: public Packet
{
private:
	Short _sport;
	Short _dport;
	Short _length;	// full packet length (with header)
	Short _checksum;
	Octet *_data;	// data octets (_length-8 of them), may be 0
	bool _own;	// true if _data was allocated with new[] and is freed by this object

	// Copying is not supported: a member-wise copy would share an owned
	// _data buffer and free it twice.
	UDP(const UDP&);
	UDP& operator=(const UDP&);
public:
	UDP(): _sport(0), _dport(0), _length(0), _checksum(0),_data(0),_own(false) { }
	UDP(unsigned sport, unsigned dport): _sport(sport),_dport(dport),_length(0),_checksum(0),_data(0),_own(false) { }

	// Set the data octets carried by the datagram.
	// data   - buffer of `length` octets (not copied)
	// length - number of data octets; the stored UDP length adds the 8-octet header
	// own    - if true, this object takes ownership of data, which must have
	//          been allocated with new[]; it is released with delete[]
	// A buffer previously owned by this object is released first (it used to leak).
	void set_data(Octet *data, Short length, bool own=false)
	{
		if (_data && _own && _data != data) delete[] _data;
		_data = data;
		_length = length+8;
		_own = own;
	}
	Octet *get_data() const { return _data; }
	Short src_port() const { return _sport; }
	Short dst_port() const { return _dport; }
	// Number of data octets (UDP length minus the 8-octet header).
	Short length() const { return _length-8; }

	string name() const { return "UDP"; }
	string to_string(int level=0) const;
	void fixup();
	bool do_build(Buffer &pkt, int phase, Word &pos);
	bool decode(Buffer &pkt);
	~UDP() { if (_data && _own) delete[] _data; }
};

// Render the header fields as "<UDP sport=.. dport=.. length=.. chk=..>".
// For level > 1, when data is present, a hex dump of the data octets
// follows: 16 octets per line, separated by spaces.
string UDP::to_string(int level) const
{
	string out;
	out += "<UDP sport=";
	out += cvt_int(_sport);
	out += " dport=";
	out += cvt_int(_dport);
	out += " length=";
	out += cvt_int(_length);
	out += " chk=";
	out += cvt_hex(_checksum);
	out += ">";

	const bool want_dump = level > 1 && _data && _length > 8;
	if (!want_dump)
		return out;

	out += '\n';
	const int count = _length - 8;
	for (int idx = 0; idx < count; ++idx) {
		const Octet octet = _data[idx];
		out += xdigits[octet >> 4];
		out += xdigits[octet & 0xf];
		out += (idx % 16 == 15) ? '\n' : ' ';
	}
	out += '\n';
	return out;
}

// Hook invoked after attach. UDP has nothing to adjust here: its length
// and checksum are computed in do_build() phase 0.
void UDP::fixup()
{
	// intentionally empty
}

// Build the UDP datagram.
// Phase 0: if a payload layer is attached, build it into a private buffer
//   that replaces _data. Then compute the length field and the checksum.
//   The checksum covers the value from underlayer()->pseudo_checksum(),
//   presumably the IP pseudo-header, plus the UDP header and the data.
// Other phases: append the 8-octet header and then the data octets to pkt.
//   Zero octets are written if no data buffer is set.
// pos is advanced by the full UDP length in either phase.
// Returns false if the payload cannot be built, if it is too long for the
// 16-bit length field, or if pkt rejects an octet.
bool UDP::do_build(Buffer &pkt, int phase, Word &pos)
{
	int i;
	Octet *p;
	Word checksum;
	if (phase == 0) {
		if (payload()) { // setup payload
			Word npos = 0;
			Buffer data;
			// Build first, so that a failure leaves the current _data
			// valid instead of freed-but-still-referenced.
			if (!payload()->do_build(data, 1, npos)) return false;
			// The UDP length field is 16 bits and includes the 8-octet header.
			if (npos > 65535 - 8) {
				if (debug > 0) cerr << "UDP::do_build(): payload too long" << endl;
				return false;
			}
			Octet *buf = new Octet[npos];
			if (npos > 0) memcpy(buf, data.buf(), npos);
			if (_data && _own) delete[] _data;	// allocated with new[]
			_data = buf;
			_own = true;	// allocated here; released by ~UDP() or the next rebuild
			_length = 8 + npos;
		}
		// otherwise the payload was prepared by set_data()
		checksum = underlayer()->pseudo_checksum(ProtoUDP, _length);
		checksum += _sport + _dport + _length;
		pos += _length;
		if (_length > 8 && _data) {
			// sum the data as 16-bit big-endian words
			p = _data;
			for (i = 0; i < _length-8-1; i+=2,p+=2) {
				checksum += (*p << 8) + p[1];
			}
			// Last odd octet, padded with a zero low byte. The header is 8
			// octets, so _length has the same parity as the data length.
			if (_length & 1) checksum += (*p << 8);
		}
		// fold the carries back into 16 bits (one's complement sum)
		checksum = (checksum & 0xffff) + (checksum >> 16);
		if (checksum > 65535) checksum -= 65535;
		_checksum = ~checksum & 0xffff;
		// RFC 768: a computed checksum of zero is transmitted as all ones,
		// because an all-zero field means "no checksum computed".
		if (_checksum == 0) _checksum = 0xffff;
		return true;
	}
	if (!pkt.add_hshort(_sport)) return false;
	if (!pkt.add_hshort(_dport)) return false;
	if (!pkt.add_hshort(_length)) return false;
	if (!pkt.add_hshort(_checksum)) return false;
	pos += _length;
	// An attached payload layer was already flattened into _data in
	// phase 0, so it is not built a second time here.
	if (_length > 8) {
		p = _data;
		for (i = 0; i < _length-8; i++) {
			if (_data) {
				if (!pkt.add_octet(*p++)) return false;
			} else {
				if (!pkt.add_octet(0)) return false;
			}
		}
	}
	return true;
}

// Decode a UDP header and its data from pkt.
// Any data buffer previously owned by this object is released first.
// If the length field exceeds 8, _data receives a newly allocated (owned)
// copy of the _length-8 data octets.
// Returns false if pkt runs out of octets, or if the length field is
// smaller than the 8-octet header.
bool UDP::decode(Buffer &pkt)
{
	Octet *p;
	if (_own && _data) delete[] _data;	// allocated with new[]
	_data = 0;
	_own = false;
	if (!pkt.get_hshort(_sport)) {
		if (debug > 0) cerr << "UDP::decode(): missing source port octets" << endl;
		return false;
	}
	if (!pkt.get_hshort(_dport)) {
		if (debug > 0) cerr << "UDP::decode(): missing destination port octets" << endl;
		return false;
	}
	if (!pkt.get_hshort(_length)) {
		if (debug > 0) cerr << "UDP::decode(): missing length octets" << endl;
		return false;
	}
	if (!pkt.get_hshort(_checksum)) {
		if (debug > 0) cerr << "UDP::decode(): missing checksum octets" << endl;
		return false;
	}
	if (_length < 8) {
		// RFC 768: the length includes the 8-octet header, so a smaller
		// value is malformed (and length() would wrap around).
		if (debug > 0) cerr << "UDP::decode(): invalid length " << static_cast<unsigned>(_length) << endl;
		return false;
	}
	if (_length > 8) {
		_data = new Octet[_length-8];
		_own = true;
		p = _data;
		for (int i = 0; i < _length-8; i++) {
			if (!pkt.get_octet(*p++)) {
				if (debug > 0) cerr << "UDP::decode(): missing data octet " << i << endl;
				return false;
			}
		}
	}
	return true;
}

// Factory: returns a new UDP object with every field zeroed (via the
// default constructor). The buffer argument is not examined here.
static Packet *UDP_factory(Buffer &pkt)
{
	UDP *udp = new UDP();
	return udp;
}

#endif
