/* -
 * Copyright (c) 1998-2005 Joao Cabral
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS  SOFTWARE  IS  PROVIDED  BY  THE UNIVERSITY OF BRADFORD AND THE AUTHOR 
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED 
 * TO, THE  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
 * PURPOSE  ARE DISCLAIMED.  IN NO EVENT SHALL THE UNIVERSITY OR THE AUTHOR BE
 * LIABLE  FOR  ANY  DIRECT,  INDIRECT,  INCIDENTAL,  SPECIAL,  EXEMPLARY,  OR 
 * CONSEQUENTIAL  DAMAGES  (INCLUDING,  BUT  NOT  LIMITED  TO,  PROCUREMENT OF 
 * SUBSTITUTE  GOODS  OR  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION)  HOWEVER  CAUSED  AND  ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT,  STRICT  LIABILITY,  OR  TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN  IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 *   DHIS(c) Dynamic Host Information System Release 5
 *
 */

#include "dhisd.h"
#include "network.h"

#include <limits.h>

/* Descriptor of the UDP socket used to receive messages; set by net_init() */
int udp_sock;
/* Port copied into the header of every outgoing message by
 * net_write_message(); defined in another translation unit.
 * NOTE(review): presumably the port replies should be sent to - confirm. */
extern int rport;


/*
 * msg_size_by_opcode() - Maps a message opcode to the size in bytes of the
 *                        structure that carries it.
 *
 * Covers the current (R5/R51), R4 and R3 message sets.
 * Returns 0 for an opcode it does not know.
 */
int msg_size_by_opcode(int opcode) {

	size_t size;

	switch(opcode) {
	/* Current release messages */
	case ECHO_REQ:		size = sizeof(echo_req_t);		break;
	case ECHO_ACK:		size = sizeof(echo_ack_t);		break;
	case AUTH_REQ:		size = sizeof(auth_req_t);		break;
	case AUTH_ACK:		size = sizeof(auth_ack_t);		break;
	case R51_AUTH_ACK:	size = sizeof(auth_ack_51_t);		break;
	case AUTH_DENY:		size = sizeof(auth_deny_t);		break;
	case AUTH_SX:		size = sizeof(auth_sendx_t);		break;
	case AUTH_SY:		size = sizeof(auth_sendy_t);		break;
	case CHECK_REQ:		size = sizeof(check_req_t);		break;
	case CHECK_ACK:		size = sizeof(check_ack_t);		break;
	case OFFLINE_REQ:	size = sizeof(offline_req_t);		break;

	/* R4 messages */
	case R4_ECHO_REQ:	size = sizeof(r4_echo_req_t);		break;
	case R4_ECHO_ACK:	size = sizeof(r4_echo_ack_t);		break;
	case R4_AUTH_REQ:	size = sizeof(r4_auth_req_t);		break;
	case R4_AUTH_ACK:	size = sizeof(r4_auth_ack_t);		break;
	case R4_AUTH_DENY:	size = sizeof(r4_auth_deny_t);		break;
	case R4_AUTH_SX:	size = sizeof(r4_auth_sendx_t);		break;
	case R4_AUTH_SY:	size = sizeof(r4_auth_sendy_t);		break;
	case R4_CHECK_REQ:	size = sizeof(r4_check_req_t);		break;
	case R4_CHECK_ACK:	size = sizeof(r4_check_ack_t);		break;
	case R4_OFFLINE_REQ:	size = sizeof(r4_offline_req_t);	break;

	/* R3 messages */
	case R3_ONLINE_REQ:	size = sizeof(r3_online_req_t);		break;
	case R3_OFFLINE_REQ:	size = sizeof(r3_offline_req_t);	break;

	default:		size = 0;				break;
	}

	return((int)size);
}

/*
 * swap_int() - Reverses, in place, the byte order of the integer at *n.
 *
 * Only the first four bytes of the object are reversed, so this assumes a
 * 4-byte int.
 */
void swap_int(int *n) {

	unsigned char *bytes = (unsigned char *)n;
	unsigned char tmp;
	int lo, hi;

	/* Swap the outer pair, then the inner pair */
	for(lo = 0, hi = 3; lo < hi; lo++, hi--) {
		tmp = bytes[lo];
		bytes[lo] = bytes[hi];
		bytes[hi] = tmp;
	}
}

/*
 * swap_msg() - Byte-swaps n consecutive ints starting at m, one at a
 *              time via swap_int(). A non-positive n does nothing.
 */
void swap_msg(int *m,int n) {

	while(n-- > 0)
		swap_int(m++);
}

/*
 * little_endian() - Reports the host byte order.
 *                   Returns 1 if the system is little endian, 0 otherwise.
 */
int little_endian(void) {

	union {
		int		word;
		unsigned char	bytes[sizeof(int)];
	} probe;

	/* On a little-endian host the low-order byte is stored first */
	probe.word = 1;
	return (probe.bytes[0] == 1) ? 1 : 0;
}

/* 
 * convert_message() - Byte-swaps, in place, the integer fields of a message
 *                     so it can move between host order and the big-endian
 *                     wire format. The callers in this file only invoke it
 *                     on little-endian hosts.
 *
 * mode: 1 - message is in host order (about to be sent); the opcode is
 *           read before the header is swapped.
 *       2 - message is in wire order (just received, or being restored
 *           after a send); the opcode is read after the header is swapped.
 *       Any other value leaves the message untouched (the opcode would
 *       otherwise be undetermined).
 *
 * The first four ints of the header are always swapped; which other fields
 * are swapped depends on the opcode.
 */
void convert_message(msg_t *p,int mode) {

	int opcode;

	if(mode!=1 && mode!=2) return;

	/* The opcode selects the body layout, so it must be read while it is
	 * in host order: before the header swap (mode 1) or after it (mode 2) */
	opcode=p->hdr.opcode;
	swap_msg((int *)&(p->hdr),4);
	if(mode==2) opcode=p->hdr.opcode;

	/* hostid is only swapped for opcodes >= ECHO_REQ.
	 * NOTE(review): presumably the older R3/R4 opcodes are numerically
	 * lower and their headers carry no hostid - confirm in network.h. */
	if(opcode >= ECHO_REQ)
		swap_int((int *)&(p->hdr.hostid));

	switch(opcode) {

	/* Current release messages */
	case(ECHO_REQ): break;
	case(ECHO_ACK):
		swap_int((int *)&((echo_ack_t *)p)->oserial);
		break;
	case(AUTH_REQ):
		swap_int((int *)&((auth_req_t *)p)->refresh);
		break;
	case(AUTH_ACK):
		swap_int((int *)&((auth_ack_t *)p)->sid);
		break;
	case(R51_AUTH_ACK):
		swap_int((int *)&((auth_ack_51_t *)p)->sid);
		break;
	case(AUTH_DENY): break;
	case(AUTH_SX):
		/* No body integers are swapped here.
		 * NOTE(review): R4_AUTH_SX swaps its id field; confirm that
		 * auth_sendx_t really has no integer fields to convert. */
		break;
	case(AUTH_SY): break;
	case(CHECK_REQ):
		swap_int((int *)&((check_req_t *)p)->next_check);
		break;
	case(CHECK_ACK):
		swap_int((int *)&((check_ack_t *)p)->sid);
		break;
	case(OFFLINE_REQ):
		swap_int((int *)&((offline_req_t *)p)->sid);
		break;

	/* R4 messages */
	case(R4_ECHO_REQ): break;
	case(R4_ECHO_ACK):
		swap_int((int *)&((r4_echo_ack_t *)p)->oserial);
		break;
	case(R4_AUTH_REQ):
		swap_int((int *)&((r4_auth_req_t *)p)->id);
		break;
	case(R4_AUTH_ACK):
		swap_int((int *)&((r4_auth_ack_t *)p)->sid);
		break;
	case(R4_AUTH_DENY): break;
	case(R4_AUTH_SX):
		swap_int((int *)&((r4_auth_sendx_t *)p)->id);
		break;
	case(R4_AUTH_SY): break;
	case(R4_CHECK_REQ):
		swap_int((int *)&((r4_check_req_t *)p)->next_check);
		break;
	case(R4_CHECK_ACK):
		swap_int((int *)&((r4_check_ack_t *)p)->sid);
		break;
	case(R4_OFFLINE_REQ):
		swap_int((int *)&((r4_offline_req_t *)p)->sid);
		break;

	/* R3 messages */
	case(R3_OFFLINE_REQ):
		swap_int((int *)&((r3_offline_req_t *)p)->id);
		swap_int((int *)&((r3_offline_req_t *)p)->pass);
		break;
	case(R3_ONLINE_REQ):
		swap_int((int *)&((r3_online_req_t *)p)->id);
		swap_int((int *)&((r3_online_req_t *)p)->pass);
		break;

	default:
		/* Unknown opcode: only the header was converted */
		break;
	}
}

/*
 * get_serial() - Returns the next outgoing message serial number.
 *
 * Serials start at 1 and increase by one per call. At INT_MAX the counter
 * wraps back to 1 instead of overflowing, because signed overflow is
 * undefined behaviour.
 */
int get_serial(void) {

	static int s=0;

	if(s==INT_MAX) s=0;
	return(++s);
}

/*
 * net_init() - Creates the UDP socket used for receiving messages and binds
 *              it to the given port on all local interfaces.
 *
 * Updates:   udp_sock only (set to -1 on failure)
 *
 * Returns:   0 on success, 1 on error
 *
 */

int net_init(int port) {

	struct sockaddr_in sa;

	/* Create UDP socket */
	udp_sock=socket(AF_INET,SOCK_DGRAM,0);
	if(udp_sock<0) return(1);

	/* Zero the whole address first so sin_zero (and sin_len on BSD
	 * systems) does not contain stack garbage */
	memset(&sa,0,sizeof(sa));
	sa.sin_family=AF_INET;
	sa.sin_port=htons(port);
	sa.sin_addr.s_addr=htonl(INADDR_ANY);
	if(bind(udp_sock,(struct sockaddr *)&sa,sizeof(sa)))
	{
		close(udp_sock);
		udp_sock=-1;	/* don't leave a stale descriptor behind */
		return(1);
	}

	/* UDP socket is ready to receive messages */
	return(0);
}

/*
 * net_close() - Closes the UDP socket opened by net_init().
 *
 * Updates:   udp_sock (set to -1, so a repeated call cannot close an
 *            unrelated descriptor that reused the same number)
 *
 * Returns:   0
 *
 */
int net_close(void) {

	if(udp_sock>=0) close(udp_sock);
	udp_sock=-1;
	return(0);
}

/* 
 * net_check_message() - Polls udp_sock without blocking.
 *
 * Returns 1 if there is a message to be read, 0 otherwise. It also returns 0
 * when select() fails or udp_sock cannot be used with select().
 */
int net_check_message(void) {

	fd_set readfds;
	struct timeval tv;

	/* FD_SET() on a negative descriptor or one >= FD_SETSIZE is undefined */
	if(udp_sock<0 || udp_sock>=FD_SETSIZE) return(0);

	/* Prepare for select with a zero timeout (poll, don't wait) */
	FD_ZERO(&readfds);
	FD_SET(udp_sock,&readfds);
	tv.tv_sec=0;
	tv.tv_usec=0;

	/* -1: error (set contents unspecified); 0: nothing ready */
	if(select(udp_sock+1,&readfds,NULL,NULL,&tv)<=0) return(0);
	return(FD_ISSET(udp_sock,&readfds) ? 1 : 0);
}

/* 
 * net_read_message() - Reads one datagram (at most MAX_MSG bytes) from
 *			udp_sock into *p and converts it to host byte order.
 *			The sender's IPv4 address (network byte order) is
 *			copied into *from.
 *
 * Returns the number of bytes read, or 0 on error.
 */
int net_read_message(msg_t *p,int *from) {

	int n;
	socklen_t sl;
	struct sockaddr_in sa;

	/* Read message; recvfrom() requires a socklen_t length, not int */
	sl=sizeof(sa);
	n=(int)recvfrom(udp_sock,p,MAX_MSG,0,(struct sockaddr *)&sa,&sl);
	if(n<=0 || n>MAX_MSG) return(0);

	DSYSLOG(1,(LOG_DEBUG,"net_read_message(): Message arrived from %s\n",
		inet_ntoa(sa.sin_addr)));

	/* Messages travel big endian; convert to host order if necessary */
	if(little_endian()) convert_message(p,2);
	memcpy(from,&sa.sin_addr,sizeof(struct in_addr));
	return(n);
}

/* 
 * net_write_message() - Fills in the common header of *p (version, serial,
 *                       rport) and sends it by UDP to toaddr:toport.
 *
 * toaddr is an IPv4 address already in network byte order; toport is in host
 * order. *p is converted to wire order only for the duration of the send
 * and is back in host order on return.
 *
 * Returns the number of bytes sent, or 0 on error. A message whose opcode
 * has no known size is not sent.
 */
int net_write_message(msg_t *p,int toaddr,int toport) {

	struct sockaddr_in sa;
	int s;
	int len;
	int r;

	{ struct in_addr ia;
	ia.s_addr=toaddr;
	DSYSLOG(1,(LOG_DEBUG,"net_write_message(): Sending Message to %s\n",
		inet_ntoa(ia)));
	}

	p->hdr.version=DHIS_VERSION;
	p->hdr.serial=get_serial();
	p->hdr.rport=rport;

	/* Get message size while the opcode is still in host order.
	 * Unknown opcodes have size 0: nothing useful to send. */
	len=msg_size_by_opcode(p->hdr.opcode);
	if(len<=0) return(0);

	/* set destination */
	if((s=socket(AF_INET,SOCK_DGRAM,0))<0) return(0);
	memset(&sa,0,sizeof(sa));
	sa.sin_family=AF_INET;
	sa.sin_port=htons(toport);
	sa.sin_addr.s_addr=toaddr;

	/* Convert to big endian if necessary */
	if(little_endian()) convert_message(p,1);

	/* Send message request */
	r=(int)sendto(s,p,len,0,(struct sockaddr *)&sa,sizeof(sa));
	close(s);

	/* Restore host order so the caller can keep using *p */
	if(little_endian()) convert_message(p,2);

	/* sendto() reports errors as -1; this function documents 0 */
	return(r<0 ? 0 : r);
}



/* qrc_random() - Sets x to a random non-negative integer of at most n
 *		  decimal digits. The digits are the decimal forms of successive
 *		  random() values concatenated and truncated to n digits.
 *		  n is clamped to 1024; n <= 0 sets x to 0.
 *
 * SECURITY NOTE(review): random() seeded with time(NULL) is predictable, and
 * this function feeds key generation (qrc_genkey). Consider a CSPRNG such as
 * /dev/urandom.
 */
void qrc_random(mpz_t x,int n) {

	/* Up to n digits, plus one whole random() chunk (at most 20 digits even
	 * for a 64-bit long), plus the NUL: the old 1024-byte buffer could
	 * overflow for n near 1024 */
	char buff[1024+32];
	static int seed=0;
	size_t len=0;

	if(!seed) { seed++; srandom(time(NULL)); }

	if(n<=0) { mpz_set_ui(x,0); return; }
	if(n>1024) n=1024;

	/* Append random chunks until at least n digits are available; tracking
	 * len avoids the repeated strlen()/strcat() scans */
	buff[0]='\0';
	while(len<(size_t)n)
		len+=(size_t)snprintf(buff+len,sizeof(buff)-len,"%lu",
			(unsigned long int)random());
	buff[n]='\0';

	mpz_set_str(x,buff,10);
}
		



/* qrc_genkey() - Sets k to a random probable prime of at most 100 decimal
 * 		  digits that is congruent to 3 mod 4. The (k+1)/4 exponent
 *		  used by qrc_sqrty() requires k to be 3 mod 4.
 *
 */

void qrc_genkey(mpz_t k) {

	/* Check the cheap residue first, so only candidates that are 3 mod 4
	 * pay for the primality test. The first candidate satisfying both
	 * conditions is still the one selected, just with fewer prime tests
	 * and no temporary mpz_t values. */
	do {
		qrc_random(k,100);
	} while(mpz_fdiv_ui(k,4)!=3 || !mpz_probab_prime_p(k,5));
}

/* qrc_genx() - Generates a random x relatively prime to n.
 *
 * A random number with about one digit fewer than n is squared modulo n, so
 * the resulting x is a quadratic residue mod n. Candidates are retried until
 * gcd(x, n) == 1.
 */
void qrc_genx(mpz_t x,mpz_t n) {

	int digits;
	mpz_t work;

	digits=(int)mpz_sizeinbase(n,10)-1;
	mpz_init(work);

	for(;;) {
		qrc_random(x,digits);
		qrc_geny(work,x,n);		/* work = x^2 mod n */
		mpz_set(x,work);
		mpz_gcd(work,x,n);
		if(mpz_cmp_ui(work,1)==0)	/* relatively prime: done */
			break;
	}

	mpz_clear(work);
}

 
/* qrc_geny() - Sets y to the quadratic residue x^2 mod n
 *		(a non-negative value below |n|).
 */
void qrc_geny(mpz_t y,mpz_t x,mpz_t n) {

	mpz_t square;

	/* Square into a temporary so y may safely alias x or n */
	mpz_init(square);
	mpz_mul(square,x,x);
	mpz_mod(y,square,n);
	mpz_clear(square);
}

/* qrc_sqrty() - Sets s = y^((k+1)/4) mod k.
 *
 * When k is a prime congruent to 3 mod 4 (as produced by qrc_genkey()) and y
 * is a quadratic residue mod k, s is a square root of y mod k. The division
 * by 4 is only exact when k is 3 mod 4.
 */

void qrc_sqrty(mpz_t s,mpz_t y,mpz_t k) {

	mpz_t exponent,four;

	mpz_init_set_ui(four,4);
	mpz_init(exponent);

	mpz_add_ui(exponent,k,1);		/* exponent = k + 1 */
	mpz_divexact(exponent,exponent,four);	/* exponent = (k + 1) / 4 */
	mpz_powm(s,y,exponent,k);

	mpz_clear(exponent);
	mpz_clear(four);
}


/* qrc_crt() - Combines residues with the Chinese Remainder Theorem.
 *
 * Uses the extended Euclidean algorithm to find s, t with
 *	s*q + t*p = gcd(q, p)
 * and sets
 *	x = xp*s*q + xq*t*p.
 * When gcd(p, q) == 1, x = xp (mod p) and x = xq (mod q).
 * NOTE(review): the gcd is not checked. x is not reduced modulo p*q, so it
 * may be negative or exceed p*q; callers that need the canonical value must
 * reduce it themselves.
 */
void qrc_crt(mpz_t x,mpz_t xp,mpz_t xq,mpz_t p,mpz_t q) {

	mpz_t gcd,s,t,term_p,term_q;

	mpz_init(gcd);
	mpz_init(s);
	mpz_init(t);
	mpz_init(term_p);
	mpz_init(term_q);

	mpz_gcdext(gcd,s,t,q,p);		/* gcd = s*q + t*p */

	mpz_mul(term_p,s,q);			/* term_p = xp*s*q */
	mpz_mul(term_p,term_p,xp);

	mpz_mul(term_q,t,p);			/* term_q = xq*t*p */
	mpz_mul(term_q,term_q,xq);

	/* Write x last so it may alias any input */
	mpz_add(x,term_p,term_q);

	mpz_clear(term_q);
	mpz_clear(term_p);
	mpz_clear(t);
	mpz_clear(s);
	mpz_clear(gcd);
}

/* qrc_fill_str() - Writes x into str as exactly n decimal digits,
 *		    left-padded with '0' characters. str is NOT
 *		    NUL-terminated. If x needs more than n digits (or its
 *		    decimal form exceeds 1022 characters), str is left
 *		    untouched. x is assumed non-negative; a '-' sign would be
 *		    copied as an ordinary character.
 */
void qrc_fill_str(mpz_t x,unsigned char *str,int n) {

	char digits[1024];
	size_t len;

	/* mpz_sizeinbase() may overestimate by one digit, so it is used only
	 * to make sure the conversion (digits, sign, NUL) fits the buffer. The
	 * exact length comes from strlen(). */
	if(mpz_sizeinbase(x,10)+2 > sizeof(digits)) return;
	mpz_get_str(digits,10,x);
	len=strlen(digits);

	if(n<0 || len>(size_t)n) return;	/* does not fit in n chars */

	memset(str,'0',(size_t)n-len);			/* leading zeros */
	memcpy(str+((size_t)n-len),digits,len);		/* the digits */
}



