import java.util.*;

/**
 * Command-line grading harness for {@code MySemaphore}.
 *
 * <p>The harness exercises the semaphore's constructor, {@code getNumTokens()},
 * {@code p()} and {@code v()} in a fixed sequence of tests, printing progress and
 * a running point total. Run it with assertions enabled ({@code java -ea SemaphoreTester})
 * so that the first failed check aborts the run with the points earned so far.
 *
 * <p>An optional first command-line argument overrides the number of threads used
 * by the stress test; points are only reported when that number is at least
 * {@link #DEFAULT_NUM_THREADS}.
 */
public class SemaphoreTester {

    /** Stress-test thread count used when no (valid) count is given on the command line. */
    public static final int DEFAULT_NUM_THREADS = 15000;

    /**
     * Checks a single test condition.
     *
     * <p>With assertions enabled, a false {@code test} throws an {@link AssertionError}
     * whose message contains {@code errorMessage} and {@code pointsSoFar}. With assertions
     * disabled, a false {@code test} prints {@code "ERROR: " + errorMessage} instead.
     *
     * @param test         the condition that is expected to hold
     * @param errorMessage description of the failure, reported when {@code test} is false
     * @param pointsSoFar  points earned before this check, included in the assertion message
     * @return the value of {@code test}
     */
    public static boolean assertTest(boolean test, String errorMessage, int pointsSoFar) {
        assert test : errorMessage + "  Total Points: " + pointsSoFar;
        if (!test) {
            System.out.println("ERROR: " + errorMessage);
        }
        return test;
    }

    /**
     * Sleeps the calling thread for the given time, printing (not propagating) an interruption.
     *
     * @param millis how long to sleep, in milliseconds
     */
    private static void sleepAndReport(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            System.out.println(e);
        }
    }

    /**
     * Waits for the given thread to finish, printing (not propagating) an interruption.
     *
     * @param thread the thread to join
     */
    private static void joinAndReport(Thread thread) {
        try {
            thread.join();
        } catch (InterruptedException e) {
            System.out.println(e);
        }
    }

    /**
     * Runs every semaphore test in order and prints the expected score.
     *
     * @param args optional; {@code args[0]} is the number of stress-test threads
     *             (values below 1 are raised to 1, unparsable values fall back to
     *             {@link #DEFAULT_NUM_THREADS})
     */
    public static void main(String[] args) {
        //set the number of threads for the stress test
        int numStressTestThreads = DEFAULT_NUM_THREADS;
        if (args.length > 0) {
            try {
                numStressTestThreads = Math.max(1, Integer.parseInt(args[0]));
            } catch (NumberFormatException e) {
                System.out.println(e);
            }
        }
        //points are only awarded when the stress test runs at full size
        final boolean scoringEnabled = numStressTestThreads >= DEFAULT_NUM_THREADS;

        int partPoints;
        int expectedPoints = 0;
        boolean partPassed = true;

        //this test checks that the number of tokens right after a semaphore is created is correct.
        for (int i = -1; i < 20; i++) {
            partPassed = partPassed && assertTest((new MySemaphore(i).getNumTokens() == i), "Constructor incorrect when given " + i + " tokens!", expectedPoints);
        }


        //this test checks that a Semaphore with positive tokens will never block.
        //(if p() blocked here, the main thread would simply hang.)
        int numStartTokens = 100;
        final MySemaphore s0 = new MySemaphore(numStartTokens);
        for (int i = 0; i < numStartTokens; i++) {
            s0.p();
        }
        for (int i = 0; i < numStartTokens; i++) {
            s0.v();
        }
        for (int i = 0; i < numStartTokens; i++) {
            s0.p();
        }
        for (int i = 0; i < numStartTokens; i++) {
            s0.v();
            s0.p();
        }
        partPoints = 20;
        if (partPassed && scoringEnabled) {
            expectedPoints += partPoints;
            System.out.println("Looks like the semaphore doesn't block when there are a positive number of tokens, good.  +" + partPoints + " points!");
        }
        partPassed = true; //reset


        //this test checks whether P blocks when there are no available tokens and that v adds a token.
        final MySemaphore s = new MySemaphore(0);
        //flags records the order of events; it is written by worker threads and by main,
        //so it must be a thread-safe list.
        final List<String> flags = Collections.synchronizedList(new ArrayList<String>());
        Thread tryP = new Thread(() -> {s.p(); flags.add("P finished");});
        tryP.start();

        sleepAndReport(1000);
        System.out.println("Sleep finished");

        flags.add("Sleep finished");
        s.v();

        partPassed = partPassed && assertTest(flags.get(0).equals("Sleep finished"),  "p() didn't block on newly-created semaphore with 0 tokens!", expectedPoints);
        if (partPassed) {
            //only claim success when the check above actually held
            System.out.println("Looks like p() blocks threads when a semaphore is created with no tokens, good!");
        }

        joinAndReport(tryP);
        partPassed = partPassed && assertTest(flags.get(1).equals("P finished"), "The thread executing p() never finished correctly!", expectedPoints);
        partPassed = partPassed && assertTest(s.getNumTokens() == 0,  "There aren't zero tokens after a v() and a p().", expectedPoints);
        partPoints = 20;
        if (partPassed && scoringEnabled) {
            expectedPoints += partPoints;
            System.out.println("The p() method blocks when there are no available tokens, good!  +" + partPoints + " points!");
        }
        partPassed = true; //reset


        //this test checks that P blocks on zero or fewer tokens
        int numThreads = 100;
        final MySemaphore s2 = new MySemaphore(-numThreads);
        s2.v(); // add one token back
        Thread tryP2 = new Thread(() -> {s2.p(); flags.add("P finished");});
        tryP2.start();

        try {
            Thread.sleep(100);
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        }

        flags.add("Sleep finished");

        partPassed = partPassed && assertTest(flags.get(2).equals("Sleep finished"), "P is not blocking on zero-or-fewer tokens!", expectedPoints);


        //this test checks that lots of v() calls will increment the token appropriately
        System.out.println("Running Multi-v Test...");
        List<Thread> threads = new ArrayList<Thread>(numThreads);
        for (int i = 0; i < numThreads; i++) {
            Thread t = new Thread(() -> {s2.v();});
            threads.add(t);
            t.start();
        }

        joinAndReport(tryP2);
        //the semaphore under test here is s2: -numThreads + 1 + numThreads v()s, minus the one p()
        partPassed = partPassed && assertTest(s2.getNumTokens() == 0, "The semaphore ended up with the wrong number of tokens!", expectedPoints);
        partPassed = partPassed && assertTest(flags.get(3).equals("P finished"), "Threads got executed out of order!", expectedPoints);

        for (Thread t : threads) {
            joinAndReport(t);
        }

        if (partPassed && scoringEnabled) {
            expectedPoints += 10;
            System.out.println("Passed the first half of the stress test!  +10 points!");
        }
        partPassed = true; //reset


        final long startTime = System.currentTimeMillis();
        //this test just throws a bunch of p()->v() threads at a Semaphore and hopes they make it.
        final MySemaphore s3 = new MySemaphore(10);

        //first drain every starting token, one p() per thread
        int startNumTokens = s3.getNumTokens();
        threads = new ArrayList<Thread>(startNumTokens);
        for (int i = 0; i < startNumTokens; i++) {
            Thread t = new Thread(() -> {s3.p();});
            threads.add(t);
            t.start();
        }

        for (Thread t : threads) {
            joinAndReport(t);
        }
        System.out.println("Launched and completed the initial threads...");

        //s3 should have zero tokens now
        partPassed = partPassed && assertTest( s3.getNumTokens() == 0, "Multiple p()s don't remove the appropriate number of tokens!", expectedPoints);


        System.out.println("Running the bigger half of the stress test with " + numStressTestThreads + " threads...");
        threads = new ArrayList<Thread>();
        //create disrupting threads that are just going to add extra work to the semaphore.
        int numDisruptorThreads = numStressTestThreads;
        for (int i = 0; i < numDisruptorThreads; i++) {
            Thread disruptor = new Thread(() -> {
                for (int j = 0; j < 2; j++) {
                    s3.p();
                    try {
                        Thread.sleep(0, 1); //hopefully we'll give up processing to another thread
                    } catch (InterruptedException e) {
                        System.out.println(e);
                    }
                    s3.v();
                }
            });
            threads.add(disruptor);
            disruptor.start();
        }

        int numThreads2 = numStressTestThreads; // if you're getting an out of memory error, drop this a bit.
        //s4 starts far enough below zero that it only reaches 1 token after every worker has called v()
        MySemaphore s4 = new MySemaphore(1-numThreads2);
        for (int i = 0; i < numThreads2; i++) {
            Thread t = new Thread(() -> {
                s3.p();
                s4.v();
                //System.out.println(s4.getNumTokens()); //uncomment this if this part is hanging
                s3.v();
            });
            threads.add(t);
            t.start();
        }
        System.out.println("Launched the p -> v threads...");

        partPassed = partPassed && assertTest(s3.getNumTokens() == 0, "s3 doesn't have the correct number of tokens at this point!", expectedPoints);

        s3.v();
        System.out.println("Opened the floodgates..."); //uncomment the above print statement if it's hanging here!

        s4.p(); //waits until s4 goes up to 1 token (from negative many)

        for (Thread t : threads) {
            joinAndReport(t);
        }
        System.out.println("All other threads should be complete!");
        s3.p(); //take back the single token released above

        final long endTime = System.currentTimeMillis();
        final long millis = endTime - startTime;

        partPassed = partPassed && assertTest(s4.getNumTokens() == 0, "Semaphore s4 didn't end up with the correct number of tokens.", expectedPoints);

        partPassed = partPassed && assertTest( s3.getNumTokens() == 0, "Semaphore s3 didn't end up with the correct number of tokens.", expectedPoints);

        System.out.println("The main testing thread made it to the end, that's something!");
        if (partPassed && scoringEnabled) {
            expectedPoints += 10;
            System.out.println("Passed the second half of the stress test!  +10 points!");
        }


        System.out.println("Completed the stress test in " + millis + " milliseconds.");
        //award some bonus points
        int bonusPoints = 0;
        if (scoringEnabled) {
            if (millis < 70000) {
                System.out.println("Your code completed in 70 seconds or less!  If this happens when Kyle grades your code, you'll get 5 bonus points.");
                bonusPoints = 5;
            } else if (millis < 90000) {
                System.out.println("Your code completed in 90 seconds or less!  If this happens when Kyle grades your code, you'll get 2 bonus points.");
                bonusPoints = 2;
            }
        }

        //detect whether the JVM was started with -ea: "assert false" only throws when assertions are on
        boolean assertionsEnabled = false;
        try {
            assert false;
        } catch (AssertionError e) {
            assertionsEnabled = true;
        }

        if (!assertionsEnabled) {
            System.out.println("In order to get an expected score, you'll need to run this with assertions enabled, like this: \n$ java -ea SemaphoreTester\n(The $ is the prompt symbol; you don't actually type that.)");
        } else if (scoringEnabled) {
            System.out.println("Just based on these tests, you'll earn " + expectedPoints + " points available from tests.  (I will also have to look at your code to check the other parts that can't be tested!)");
            if (bonusPoints > 0) {
                System.out.println("(I expect you will earn an extra " + bonusPoints + " points for a total of " + (expectedPoints + bonusPoints) + " before I look for the other parts.)");
            }
        } else {
            System.out.println("In order to get an expected score, you'll need to use at least the default number of threads (" + DEFAULT_NUM_THREADS + ").  If you can't run that yourself, make sure you fix all the errors pointed out here, then submit it to canvas for testing.");
        }
    }
}