123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- import java.util.Random;
- import java.util.Arrays;
-
- class Main {
- static void usage() {
- System.out.println("Usage: java Main <n> <k> [seq|par]");
- System.out.println(" or: java Main test");
- }
-
- static final int randSeed = 1337; // Just to keep things consistent
-
- static class Timer {
- long start = 0;
- long time = 0;
-
- void start() {
- start = System.nanoTime();
- }
- void end() {
- time = System.nanoTime() - start;
- }
-
- String prettyTime() {
- if (time < 1000)
- return time+"ns";
- else if (time < 1000000)
- return (time / 1000)+"μs";
- else if (time < 1000000000)
- return (time / 1000000)+"ms";
- else
- return (time / 1000000000)+"s";
- }
-
- String prettySpeedup(Timer base) {
- double speedup = (double)base.time / (double)time;
- return String.format("%s (%.2fx speedup)",
- prettyTime(), speedup);
- }
- }
-
- static class TestResult {
- Timer baseTime = null;
- Timer seqTime = null;
- Timer parTime = null;
- int n = 0;
- int k = 0;
-
- // Print a test result
- void print() {
- String tmpl =
- "\nTest results for n=%d, k=%d:\n"+
- "\tArrays.sort: %s\n"+
- "\tSequential: %s\n"+
- "\tParallel: %s\n";
-
- System.out.printf(tmpl,
- n, k, baseTime.prettyTime(),
- seqTime.prettySpeedup(baseTime),
- parTime.prettySpeedup(baseTime));
- }
-
- // Create a TestResult which is the average
- // of n other TestResults
- static TestResult avg(TestResult[] results) {
- TestResult ts = new TestResult();
- ts.baseTime = new Timer();
- ts.seqTime = new Timer();
- ts.parTime = new Timer();
- ts.n = results[0].n;
- ts.k = results[0].k;
-
- for (TestResult t: results) {
- if (t.n != ts.n)
- throw new RuntimeException("Bad n value");
- if (t.k != ts.k)
- throw new RuntimeException("Bad k value");
-
- ts.baseTime.time += t.baseTime.time;
- ts.seqTime.time += t.seqTime.time;
- ts.parTime.time += t.parTime.time;
- }
-
- ts.baseTime.time /= results.length;
- ts.seqTime.time /= results.length;
- ts.parTime.time /= results.length;
- return ts;
- }
- }
-
- static TestResult runTest(int[] a, int n, int k, Timer baseTime) {
-
- TestResult ts = new TestResult();
-
- Timer seqTime = new Timer();
- Timer parTime = new Timer();
-
- A2Solver seq = new Sequential();
- A2Solver par = new Parallel();
-
- int[] copy = new int[n];
-
- // Test seq
- System.arraycopy(a, 0, copy, 0, n);
- seqTime.start();
- seq.solveA2(copy, k);
- seqTime.end();
-
- // Test par
- System.arraycopy(a, 0, copy, 0, n);
- parTime.start();
- par.solveA2(copy, k);
- parTime.end();
-
- ts.baseTime = baseTime;
- ts.seqTime = seqTime;
- ts.parTime = parTime;
- ts.n = n;
- ts.k = k;
- return ts;
- }
-
- static void runTests() {
- int[] ns = { 100000000, 10000000, 1000000, 100000, 10000, 1000 };
- int[] ks = { 100, 20 };
- int iters = 7;
-
- for (int n: ns) {
-
- Timer baseTime = new Timer();
-
- // Construct array
- Random rand = new Random(randSeed);
- int[] a = new int[n];
- for (int i = 0; i < n; ++i) {
- a[i] = rand.nextInt();
- }
-
- // Time Arrays.sort()
- // I only do this once per n because it takes a long time.
- int[] copy = new int[n];
- System.arraycopy(a, 0, copy, 0, n);
- baseTime.start();
- Arrays.sort(copy);
- baseTime.end();
-
- for (int k: ks) {
- TestResult[] results = new TestResult[iters];
- for (int i = 0; i < iters; ++i) {
- results[i] = runTest(a, n, k, baseTime);
- }
-
- TestResult t = TestResult.avg(results);
- t.print();
- }
- }
- }
-
- public static void main(String[] args) {
- if (args.length < 1) {
- usage();
- System.exit(1);
- }
-
- if (args[0].equals("test")) {
- runTests();
- return;
- }
-
- A2Solver solver;
- int n, k;
-
- try {
- n = Integer.parseInt(args[0]);
- k = Integer.parseInt(args[1]);
- if (args.length == 3) {
- if (args[2].equals("seq"))
- solver = new Sequential();
- else if (args[2].equals("par"))
- solver = new Parallel();
- else {
- usage();
- System.exit(1);
- return;
- }
- } else {
- solver = new Sequential();
- }
- } catch (Exception ex) {
- System.out.println(ex.getMessage());
- System.out.println("");
- usage();
- System.exit(1);
- return;
- }
-
- Timer timer = new Timer();
-
- // Create array
- int[] a = new int[n];
- Random r = new Random(randSeed);
- for (int i = 0; i < n; ++i) {
- a[i] = r.nextInt();
- }
-
- // Solve
- timer.start();
- solver.solveA2(a, k);
- timer.end();
-
- // Print
- for (int i = 0; i < k; ++i) {
- if (i != 0)
- System.out.print(" ");
- System.out.print(a[i]);
- }
- System.out.println("");
-
- // Print report info to stderr, so that
- // stdout is only the result
- System.err.printf("\nn=%d, k=%d, %s time: %s\n",
- n, k, solver.getClass().getName(), timer.prettyTime());
- }
- }
|