/* Rev:$Revision: 1.0 $ */
/******************************************************************************
 * Copyright 2009 NetBurner, Inc.  ALL RIGHTS RESERVED
 * Permission is hereby granted to purchasers of NetBurner Hardware
 * to use or modify this computer program for any use as long as the
 * resultant program is only executed on NetBurner provided hardware.
 *
 * No other rights to use this program or its derivatives in part or
 * in whole are granted. It may be possible to license this or other NetBurner
 * software for use on non-NetBurner hardware. Please contact sales@netburner.com
 * for more information.
 *
 * NetBurner makes no representation or warranties with respect to the
 * performance of this computer program, and specifically disclaims
 * any responsibility for any damages, special or consequential,
 * connected with the use of this program.
 *
 * NetBurner, Inc.
 * 5405 Morehouse Dr
 * San Diego Ca, 92131
 *
 *****************************************************************************/
/*-------------------------------------------------------------------
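 * Windows PC TCP Speed Test application.
 *
 * Usage: SpeedTest [-r] <target address>
 * Connects to a target listening on TCP port LISTEN_PORT_NUMBER and
 * measures throughput in both directions: after sending "T" the PC
 * receives NUM_OF_BYTES from the target (transmit test); after sending
 * "R" the PC sends NUM_OF_BYTES to the target (receive test). With -r
 * the pair of tests repeats indefinitely.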
 *------------------------------------------------------------------*/
#include 
#include 
#include 
#include 
#include 

/* Report a failed Winsock call: the caller's label, then WSAGetLastError(). */
#define PRINTERROR(label) \
   do { fprintf( stderr, "\n%s: %d\n", (label), WSAGetLastError() ); } while ( 0 )

/* Test parameters shared by the transmit and receive tests. */
enum
{
   LISTEN_PORT_NUMBER = 1234,      /* Target's listen port number             */
   BUFFER_SIZE        = 100000,    /* Bytes passed to each send()/recv() call */
   NUM_OF_BYTES       = 10000000   /* Byte count each test runs until         */
};

/* Transfer buffer used by both tests; the 10 bytes beyond BUFFER_SIZE are
 * slack that the visible send()/recv() calls do not touch. */
char Buffer[BUFFER_SIZE + 10];

int main(int argc, char ** argv)
{
   WORD        wVersionRequested = MAKEWORD(1,1);
   WSADATA     wsaData;
   LPHOSTENT   lpHostEntry;
   int         nRet;
   SOCKET      mySocket;
   SOCKADDR_IN saServer;
   const char *ipAddr;
   bool bRepeat = FALSE;
   double fMax   = 0.0;
   double fMin   = 9999999.0;
   double nTotal = 0;

   // Initialize WinSock and check the version
   nRet = WSAStartup( wVersionRequested, &wsaData );
   if ( wsaData.wVersion != wVersionRequested )
   {
      fprintf( stderr, "\n ERROR: Incompatible WinSock Version\n" );
      return -1;
   }

   if ( argc <= 1 )
   {
	   printf("Usage: SpeedTest \n" );
	   return -1;
   }

   for ( int i = 1; i < argc; i++ )
   {
      if ( argv[i][0] == '-' )
      {
         if ( ( argv[i][1] == 'r' ) || ( argv[i][1] == 'R' ) )
         {
            bRepeat = TRUE;
         }
      }
      else
         ipAddr = argv[i];
   }

   printf("Get Host by Name for: \"%s\"\n", ipAddr);
   lpHostEntry = gethostbyname(ipAddr);
   if (lpHostEntry == NULL)
   {
      PRINTERROR("A: gethostbyname()");
      return -1;
   }

   do
   {
	  // Create a TCP/IP stream socket
	  mySocket = socket(AF_INET,				// Address family
	                    SOCK_STREAM,			// Socket type
                        IPPROTO_TCP);			// Protocol

	  if (mySocket == INVALID_SOCKET)
      {
         PRINTERROR("B: socket()");
         return -1;
      }

      // Fill in the address structure
      saServer.sin_family = AF_INET;
      saServer.sin_addr = *((LPIN_ADDR)*lpHostEntry->h_addr_list); // Server's address
      saServer.sin_port = htons(LISTEN_PORT_NUMBER);   // Port number from command line

      // Connect to the target device
	  nRet = connect(mySocket,                   // Socket
                     (LPSOCKADDR)&saServer,	     // Server address
                     sizeof(struct sockaddr) );  // Length of server address structure

      if ( nRet == SOCKET_ERROR )
      {
         PRINTERROR("C: socket()");
         closesocket(mySocket);
         return -1;
      }

      //========== Start Transmit Test from embedded device to PC ==========
	  {
         printf( "Starting Transmit Test\n" );
         int sv = 200000;   // size of socket receive buffer

         nRet = setsockopt( mySocket,          // Specify our socket
		                    SOL_SOCKET,        // Set socket options
	                        SO_RCVBUF,         // Total per-socket reserved buffer space
				            (const char *)&sv, // Pointer to buffer size
							sizeof(sv) );      // Length of option, sv
         printf( "Set Socket option return value = %d\n", nRet );

         // Start transmit test
         long BytesRead = 0;
		 send( mySocket, "T", 1, 0 );
         DWORD StartTime = GetTickCount();  // Start time in ms
         while ( BytesRead < NUM_OF_BYTES )
         {
            nRet = recv( mySocket, Buffer, BUFFER_SIZE, 0);
            if ( nRet > 0 )
			{
               BytesRead += nRet;
			}
			else
			{
				printf("BytesRead: %ld, nRet: %d\n", BytesRead, nRet);
				break;
			}
         }

         DWORD EndTime = GetTickCount();  // End time in ms
         printf( "Complete\n" );
		 closesocket( mySocket );

		 double TestTime = (double)( EndTime - StartTime );
         TestTime /= 1000.0;  // Convert to seconds
         double DataRate = (double)BytesRead / TestTime;
         DataRate = DataRate*8.0/1000000.0;  // convert to M bits/sec

         // Update high/low watermarks for repeat tests
		 if ( DataRate > fMax )
            fMax = DataRate;
		 if ( DataRate < fMin )
	        fMin = DataRate;

		 nTotal += BytesRead;

		 if ( bRepeat )
		 {
            printf("Test Time : %g seconds\n", TestTime );
            printf("Data Read : %ld bytes\n", BytesRead );
            printf("Data Rate : %g Mbps\n", DataRate );
			printf("Max Data Rate: %g, Min Data Rate: %g\n", fMax, fMin );
		 }
         else
		 {
		    printf("Test Complete\n");
			printf("Test Time : %g seconds\n", TestTime );
			printf("Data Read : %ld bytes\n", BytesRead );
            printf("Data Rate : %g Mbps\n", DataRate );
         }
      }


      printf("Delay before RX test\n");
	  Sleep(5000);

      //========== Start Receive Test from PC to embedded device ==========

	  // Create a TCP/IP stream socket
	  mySocket = socket(AF_INET,				// Address family
	                    SOCK_STREAM,			// Socket type
                        IPPROTO_TCP);			// Protocol

	  if (mySocket == INVALID_SOCKET)
      {
         PRINTERROR("B: socket()");
         return -1;
      }

      // Fill in the address structure
      saServer.sin_family = AF_INET;
      saServer.sin_addr = *((LPIN_ADDR)*lpHostEntry->h_addr_list); // Server's address
      saServer.sin_port = htons(LISTEN_PORT_NUMBER);   // Port number from command line

      // Connect to the target device
	  nRet = connect(mySocket,                   // Socket
                     (LPSOCKADDR)&saServer,	     // Server address
                     sizeof(struct sockaddr) );  // Length of server address structure

      if ( nRet == SOCKET_ERROR )
      {
         PRINTERROR("C: socket()");
         closesocket(mySocket);
         return -1;
      }

	  {
         printf( "Starting Receive Test\n" );
         long BytesSent = 0;
         int sv = 200000;   // size of socket receive buffer

         nRet = setsockopt( mySocket,          // Specify our socket
		                    SOL_SOCKET,        // Set socket options
	                        SO_RCVBUF,         // Total per-socket reserved buffer space
				            (const char *)&sv, // Pointer to buffer size
							sizeof(sv) );      // Length of option, sv
         printf( "Set Socket option = %d\n", nRet );

         for ( int i = 0; i < BUFFER_SIZE; i++ )
			 Buffer[i] = i % 255;

		 send( mySocket, "R", 1, 0 );

         DWORD StartTime = GetTickCount();  // Start time in ms
         while ( BytesSent < NUM_OF_BYTES )
         {
            nRet = send( mySocket, Buffer, BUFFER_SIZE, 0);
            if ( nRet > 0 )
			{
               BytesSent += nRet;
			}
            else
            {
               printf( "Failed \r\n" );
			   exit(-1);
               return 0;
            }
         }

         DWORD EndTime = GetTickCount();  // End time in ms
         closesocket( mySocket );
         printf( "Complete\n" );

		 double TestTime = (double)( EndTime - StartTime );
         TestTime /= 1000.0;  // Convert to seconds
         double DataRate = (double)BytesSent / TestTime;
         DataRate = DataRate*8.0/1000000.0;  // convert to M bits/sec

         // Update high/low watermarks for repeat tests
		 if ( DataRate > fMax )
            fMax = DataRate;
		 if ( DataRate < fMin )
	        fMin = DataRate;

		 nTotal += BytesSent;

		 if ( bRepeat )
		 {
            printf("Test Time : %g seconds\n", TestTime );
            printf("Data Sent : %ld bytes\n", BytesSent );
            printf("Data Rate : %g Mbps\n", DataRate );
			printf("Max Data Rate: %g, Min Data Rate: %g\n", fMax, fMin );
		 }
         else
		 {
		    printf("Test Complete\n");
			printf("Test Time : %g seconds\n", TestTime );
			printf("Data Sent : %ld bytes\n", BytesSent );
            printf("Data Rate : %g Mbps\n", DataRate );
         }
      }

   } while ( bRepeat == TRUE );
   return 0;
}