package com.knutejohnson.pi;

import java.io.*;
import java.net.*;
import java.util.*;
import java.util.concurrent.*;

/**
 * Command line Java program to locate servers by scanning an inclusive range
 * of IP addresses, using either {@link InetAddress#isReachable(int)} (an ICMP
 * ping, or a TCP connection on port 7, echo) or a TCP connection on a
 * specified port.
 * <p>
 * java -jar IPScanner.jar startAddress endAddress [port or ping] [timeout]
 * <p>
 * Default is ping and 250ms timeout.  Both addresses must belong to the same
 * address family (IPv4 or IPv6) and the start address must not be greater
 * than the end address.  A port outside 1-65535 or a negative timeout is
 * reported and the corresponding default is used instead.
 * <p>
 * When the JVM cannot obtain the privilege to send ICMP echo requests (for
 * example an unprivileged process under Linux), InetAddress#isReachable, used
 * for ping mode, falls back to a TCP connection on port 7.  Ping mode then
 * only finds hosts whose port 7 is open through any firewall in between.
 * !CAUTION!  Having port 7 (echo) open, particularly for UDP traffic, is
 * generally considered a security risk.
 * <p>
 * @author  Knute Johnson
 * @version 0.34 - 19 September 2018
 */
public class IPScanner {
    /** Program version */
    public static final String VERSION = "0.34";

    /** Program date */
    public static final String DATE = "19 September 2018";

    /** Port sentinel meaning "ping with InetAddress#isReachable". */
    private static final int PING = -1;

    /** Default connect/ping timeout in milliseconds. */
    private static final int DEFAULT_TIMEOUT = 250;

    /** Maximum number of worker threads probing addresses concurrently. */
    private static final int MAX_THREADS = 50;

    /**
     * Command line program entry point.
     * <p>
     * Prints one line per scanned address, ordered by address.  The value is
     * the address string if the host answered, "not reachable" for a failed
     * ping, or the text of the IOException raised while probing it.
     *
     * @param   args startAddress endAddress [portNumber or ping] [timeout]
     */
    public static void main(String... args) {
        System.out.printf("IPScanner Version: %s - Date: %s\n",VERSION,DATE);

        if (args.length < 2 || args.length > 4) {
            System.out.println("Usage: java -jar IPScanner.jar " +
             "startAddress endAddress [portNumber or ping] [timeout]");
            return;
        }

        try {
            long start = System.currentTimeMillis();
            InetAddress first = InetAddress.getByName(args[0]);
            InetAddress last = InetAddress.getByName(args[1]);
            byte[] startAddr = first.getAddress();
            byte[] endAddr = last.getAddress();

            // IPv4 and IPv6 arrays never compare equal, so the range loop in
            // buildQueue would never terminate.
            if (startAddr.length != endAddr.length) {
                System.out.println("start and end addresses must be the same " +
                 "address family: terminating");
                return;
            }

            // A reversed range would wrap around the whole address space.
            Comparator<InetAddress> order = new InetAddressComparator();
            if (order.compare(first, last) > 0) {
                System.out.println("start address is greater than end " +
                 "address: terminating");
                return;
            }

            int port = parsePort(args);
            int timeout = parseTimeout(args);

            Queue<InetAddress> queue = buildQueue(startAddr, endAddr);
            Map<InetAddress,String> results =
             new ConcurrentSkipListMap<>(order);
            scan(queue, results, port, timeout);

            // print the results, ordered by address
            results.entrySet().forEach(System.out::println);

            long stop = System.currentTimeMillis();
            System.out.printf("Search time: %.3f seconds\n",
             (stop - start) / 1000.0);
        } catch (UnknownHostException uhe) {
            System.out.println("bad IP address(es): terminating");
        } catch (InterruptedException ie) {
            // restore the interrupt status instead of swallowing it
            Thread.currentThread().interrupt();
            System.out.println("interrupted: terminating");
        }
    }

    /**
     * Reads the optional port argument (args[2]).
     *
     * @param   args command line arguments
     * @return  a port in 1-65535, or {@link #PING} when the argument is
     *          absent, "ping", or invalid
     */
    private static int parsePort(String[] args) {
        if (args.length < 3 || args[2].equalsIgnoreCase("ping")) {
            return PING;
        }
        try {
            int port = Integer.parseInt(args[2]);
            // out-of-range ports would make InetSocketAddress throw in the
            // worker threads and silently drop every result
            if (port >= 1 && port <= 65535) {
                return port;
            }
        } catch (NumberFormatException nfe) {
            // reported below
        }
        System.out.println("bad port: using default: ping");
        return PING;
    }

    /**
     * Reads the optional timeout argument (args[3]).
     *
     * @param   args command line arguments
     * @return  a non-negative timeout in milliseconds, or
     *          {@link #DEFAULT_TIMEOUT} when the argument is absent or invalid
     */
    private static int parseTimeout(String[] args) {
        if (args.length < 4) {
            return DEFAULT_TIMEOUT;
        }
        try {
            int timeout = Integer.parseInt(args[3]);
            // negative timeouts make Socket#connect and
            // InetAddress#isReachable throw IllegalArgumentException
            if (timeout >= 0) {
                return timeout;
            }
        } catch (NumberFormatException nfe) {
            // reported below
        }
        System.out.println("bad timeout: using default: 250ms");
        return DEFAULT_TIMEOUT;
    }

    /**
     * Builds a queue holding every address from startAddr through endAddr,
     * inclusive, in ascending order.  Both arrays are modified.
     *
     * @param   startAddr first address; same length as endAddr and not
     *          greater than it
     * @param   endAddr last address (inclusive)
     * @return  a thread-safe queue of the addresses to probe
     * @throws  UnknownHostException if the array length is not a legal IP
     *          address length
     */
    private static Queue<InetAddress> buildQueue(byte[] startAddr,
     byte[] endAddr) throws UnknownHostException {
        Queue<InetAddress> queue = new ConcurrentLinkedQueue<>();
        // Increment the end address to make the range inclusive.  If it wraps
        // to all zeros, the start address also wraps there right after the
        // last address.  do/while keeps that full-range case from producing
        // an empty queue.
        incIPArray(endAddr);
        do {
            queue.add(InetAddress.getByAddress(startAddr));
            incIPArray(startAddr);
        } while (!Arrays.equals(startAddr, endAddr));
        return queue;
    }

    /**
     * Probes every queued address using up to {@link #MAX_THREADS} worker
     * threads and waits for them all to finish.
     *
     * @param   queue addresses to probe; drained by the workers
     * @param   results map receiving one entry per probed address
     * @param   port TCP port to connect to, or {@link #PING}
     * @param   timeout connect/ping timeout in milliseconds
     * @throws  InterruptedException if interrupted while waiting for workers
     */
    private static void scan(Queue<InetAddress> queue,
     Map<InetAddress,String> results, int port, int timeout)
     throws InterruptedException {
        Thread[] workers = new Thread[Math.min(queue.size(), MAX_THREADS)];
        for (int i = 0; i < workers.length; i++) {
            workers[i] = new Thread(() -> {
                InetAddress inet;
                while ((inet = queue.poll()) != null) {
                    probe(inet, port, timeout, results);
                }
            });
            workers[i].start();
        }

        // wait for the workers above to finish
        for (Thread worker : workers) {
            worker.join();
        }
    }

    /**
     * Probes a single address and records the outcome in results.
     *
     * @param   inet address to probe
     * @param   port TCP port to connect to, or {@link #PING} to use
     *          InetAddress#isReachable
     * @param   timeout connect/ping timeout in milliseconds
     * @param   results map receiving the outcome for inet
     */
    private static void probe(InetAddress inet, int port, int timeout,
     Map<InetAddress,String> results) {
        try {
            if (port != PING) {
                // try-with-resources closes the socket even when connect fails
                try (Socket socket = new Socket()) {
                    socket.connect(new InetSocketAddress(inet, port), timeout);
                }
                results.put(inet, inet.getHostAddress());
            } else if (inet.isReachable(timeout)) {
                results.put(inet, inet.getHostAddress());
            } else {
                results.put(inet, "not reachable");
            }
        } catch (IOException ioe) {
            results.put(inet, ioe.toString());
        }
    }

    /**
     * Increments, in place, the unsigned big-endian value stored in a byte
     * array.  The all-0xFF value wraps around to all zeros.
     *
     * @param   array byte array to increment
     */
    private static void incIPArray(byte[] array) {
        for (int i = array.length - 1; i >= 0; i--) {
            if (++array[i] != 0) {
                return;     // no carry into the next more significant byte
            }
        }
    }

    /**
     * Comparator used to order InetAddress objects by their IP addresses.
     * Shorter (IPv4) addresses sort before longer (IPv6) ones.
     */
    private static class InetAddressComparator implements
     Comparator<InetAddress> {
        /**
         * Compares two InetAddresses by length, then byte for byte as
         * unsigned values.
         *
         * @param a1 first InetAddress to be compared
         * @param a2 second InetAddress to be compared
         *
         * @return an integer to specify whether the first InetAddress is less
         *          than, equal to, or greater than the second InetAddress
         */
        @Override
        public int compare(InetAddress a1, InetAddress a2) {
            byte[] arr1 = a1.getAddress();
            byte[] arr2 = a2.getAddress();

            // also keeps the loop below from indexing past the shorter array
            if (arr1.length != arr2.length) {
                return Integer.compare(arr1.length, arr2.length);
            }
            for (int i = 0; i < arr1.length; i++) {
                int cmp = Integer.compare(Byte.toUnsignedInt(arr1[i]),
                 Byte.toUnsignedInt(arr2[i]));
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        }
    }
}