Java で Erlang 風プロセスを作ってみる

Erlang を使ったことはないから、Erlang 風かどうかは実際のところ分からないんだけど…。他の人が作ってるのを見た感じ、まぁ大体あってるんじゃないかなーという希望を持ちつつ。メッセージボックス付きのスレッドってな感じで作った。

import java.util.*;
import java.util.concurrent.*;

/**
 * A lightweight "Erlang-style" process: a thread paired with its own message box.
 *
 * <p>Any thread that calls {@link #getCurrent()} is lazily registered as a process, and
 * {@link #spawn(Runnable)} starts a new thread registered as a child of the calling process.
 * Each process owns an unbounded queue of incoming messages.
 */
public class Process {

    /**
     * Registry mapping each thread to its process. A concurrent map keeps lookups lock-free,
     * since every message-retrieval call goes through {@link #getCurrent()}.
     */
    private static final ConcurrentMap<Thread, Process> processes =
            new ConcurrentHashMap<Thread, Process>();

    /** Private sentinel that turns {@link #putMessage(Object)} into a kill request. */
    private static final Object QUIT = new Object();

    /** Public sentinel message that callers may exchange, e.g. to signal completion. */
    public static final Object BYE = new Object();

    private final Thread thread;

    /**
     * Processes spawned by this one. Must be thread-safe: the owner adds to it in spawn(),
     * while kill()/killChildren() may iterate it from any other thread.
     */
    private final List<Process> children;

    private final BlockingQueue<Object> messageBox;

    protected Process(Thread thread) {
        this.thread = thread;
        this.children = new CopyOnWriteArrayList<Process>();
        this.messageBox = new LinkedBlockingQueue<Object>();
    }

    /**
     * Returns the process bound to the calling thread, registering one on first use.
     *
     * @return the calling thread's process (never {@code null})
     */
    public static Process getCurrent() {
        Thread t = Thread.currentThread();
        Process p = processes.get(t);
        if (p == null) {
            Process created = new Process(t);
            p = processes.putIfAbsent(t, created);
            if (p == null) {
                p = created;
            }
        }
        return p;
    }

    /**
     * Starts {@code body} on a new thread, registered as a child of the calling process.
     *
     * @param body the code the new process runs
     * @return the new, already-started process
     */
    public static Process spawn(Runnable body) {
        Process p = new Process(new Thread(body) {
                @Override
                public void run() {
                    try {
                        super.run();
                    } finally {
                        // Deregister even if the body throws, so dead threads do not leak.
                        processes.remove(this);
                    }
                }
            });

        Process current = getCurrent();
        current.children.add(p);
        // Register before start() so the child's own getCurrent() finds this Process,
        // and therefore the same message box that callers are sending to.
        processes.put(p.thread, p);
        p.thread.start();
        return p;
    }

    /**
     * Non-blocking receive for the calling process.
     *
     * @return the next message, or {@code null} if the message box is empty
     */
    public static Object getMessage() {
        return getCurrent().messageBox.poll();
    }

    /**
     * Blocking receive for the calling process.
     *
     * @return the next message
     * @throws InterruptedException if the thread is interrupted (e.g. by {@link #kill()})
     */
    public static Object takeMessage() throws InterruptedException {
        return getCurrent().messageBox.take();
    }

    /**
     * Blocking receive with a timeout for the calling process.
     *
     * @param timeout how long to wait
     * @param unit the unit of {@code timeout}
     * @return the next message, or {@code null} if the timeout elapsed first
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public static Object takeMessage(long timeout, TimeUnit unit) throws InterruptedException {
        return getCurrent().messageBox.poll(timeout, unit);
    }

    /**
     * Requests termination of this process and, recursively, of its children.
     * The thread is interrupted; it only actually stops if its body reacts to the interrupt.
     */
    public void kill() {
        putMessage(QUIT);
    }

    /** Kills every process spawned by this one. Safe to call from any thread. */
    public void killChildren() {
        for (Process c : children) {
            c.kill();
        }
    }

    /**
     * Delivers {@code message} to this process's message box. The internal QUIT sentinel is
     * handled here instead of being enqueued: it kills the children, interrupts the thread
     * if it is alive, and removes the thread from the registry. This work runs on the
     * caller's thread.
     *
     * @param message the message to deliver
     */
    public void putMessage(Object message) {
        if (message == QUIT) {
            killChildren();
            if (thread.isAlive()) {
                thread.interrupt();
            }
            processes.remove(thread);
        } else {
            messageBox.offer(message);
        }
    }

}

以下サンプル。素数の数を数えあげる処理を、上の機構を使って並列処理でやってみたもの。悲しいかな、シングルスレッド版のほうが速い。processes 周りで synchronized してるところが、だいぶ (鈍化に) 効いてるような気がしてる。

import java.util.*;

public class ConcurrentPrime {

    /**
     * Trial-division primality test.
     *
     * @param n the number to test
     * @return {@code true} iff {@code n} is prime; values below 2 are never prime
     */
    private static boolean isPrime(int n) {
        if (n < 2) return false;
        if (n == 2) return true;
        if (n % 2 == 0) return false;

        // Compute the bound once instead of re-evaluating Math.sqrt on every iteration.
        int limit = (int) Math.sqrt(n);
        for (int m = 3; m <= limit; m += 2) {
            if (n % m == 0) return false;
        }
        return true;
    }

    /** Worker process: receives candidate numbers and forwards each prime to the counter. */
    private static class Calc implements Runnable {

        final Process counter;

        Calc(Process counter) {
            this.counter = counter;
        }

        public void run() {
            try {
                while (true) {
                    Object m = Process.takeMessage();
                    if (m == Process.BYE) {
                        break;
                    }
                    int n = ((Integer) m).intValue();
                    if (isPrime(n)) {
                        counter.putMessage(n);
                    }
                }
            } catch (InterruptedException quit) {
                // Killed: restore the interrupt flag; completion is still reported below.
                Thread.currentThread().interrupt();
            } finally {
                // Always report completion, otherwise the counter waits forever.
                counter.putMessage(Process.BYE);
            }
        }

    }

    /**
     * Counter process: counts every non-BYE message as one prime, stops after receiving
     * {@code procs} BYE messages, prints the total, and then says BYE to its creator.
     */
    private static class Counter implements Runnable {

        private final Process parent;
        private final int procs;

        public Counter(int procs) {
            this.parent = Process.getCurrent();
            this.procs = procs;
        }

        public void run() {
            try {
                int count = 0;
                int end = 0;
                while (end < procs) {
                    Object m = Process.takeMessage();
                    if (m == Process.BYE) {
                        ++end;
                    } else {
                        ++count;
                    }
                }
                System.out.println("prime count = " + count);
            } catch (InterruptedException quit) {
                Thread.currentThread().interrupt();
            } finally {
                parent.putMessage(Process.BYE);
            }
        }

    }

    /**
     * Coordinator process: spawns one counter and {@code procLimit} workers, distributes
     * 2..max (inclusive, matching the sequential version), and says BYE to its creator
     * once the counter has finished.
     */
    private static class Balancer implements Runnable {

        private final Process parent;
        private final int max;
        private final int procLimit;

        public Balancer(int max, int procLimit) {
            this.max = max;
            this.procLimit = procLimit;
            this.parent = Process.getCurrent();
        }

        public void run() {
            try {
                Process counter = Process.spawn(new Counter(procLimit));
                if (max >= 2) {
                    // 2 is counted directly; the workers handle 3..max.
                    counter.putMessage(2);
                }

                List<Process> calcs = new ArrayList<Process>();
                for (int i = 0; i < procLimit; ++i) {
                    calcs.add(Process.spawn(new Calc(counter)));
                }

                // Round-robin the candidates 3..max over the workers.
                for (int j = 3; j <= max; ++j) {
                    calcs.get((j - 3) % procLimit).putMessage(j);
                }

                for (Process calc : calcs) {
                    calc.putMessage(Process.BYE);
                }

                // Wait for the counter's BYE before reporting completion.
                Process.takeMessage();
            } catch (InterruptedException quit) {
                Thread.currentThread().interrupt();
            } finally {
                // Notify the parent even on interruption, so it never blocks forever.
                parent.putMessage(Process.BYE);
            }
        }

    }

    /** Runs {@code r} once and prints its wall-clock duration in milliseconds. */
    private static void bench(String name, Runnable r) {
        long start = System.nanoTime();
        r.run();
        long end = System.nanoTime();

        System.out.printf("%s: %.2f ms%n", name, (end - start) / 1000000.0);
    }

    /**
     * Usage: {@code ConcurrentPrime [max] [procs]}. Missing, non-numeric or non-positive
     * arguments fall back to max=500000 and procs=5.
     */
    public static void main(String[] args) {
        int n = -1;
        int m = -1;
        if (0 < args.length) {
            try {
                n = Integer.parseInt(args[0]);
            } catch (NumberFormatException ignored) {
                // Deliberate: fall back to the default below.
            }
            if (1 < args.length) {
                try {
                    m = Integer.parseInt(args[1]);
                } catch (NumberFormatException ignored) {
                    // Deliberate: fall back to the default below.
                }
            }
        }

        final int max = n > 0 ? n : 500000;
        final int procs = m > 0 ? m : 5;

        bench("concurrent " + max + " with " + procs, new Runnable() {
                public void run() {
                    Process.spawn(new Balancer(max, procs));
                    try {
                        Process.takeMessage();
                    } catch (InterruptedException quit) {
                        Thread.currentThread().interrupt();
                    }
                }
            });

        bench("sequencial " + max, new Runnable() {
                public void run() {
                    int count = 0;
                    for (int i = 2; i <= max; ++i) {
                        if (isPrime(i)) ++count;
                    }
                    System.out.println("prime count = " + count);
                }
            });
    }

}

以下、実行結果。

$ java ConcurrentPrime 10000000 5
prime count = 664579
concurrent 10000000 with 5: 12244.45 ms
prime count = 664579
sequencial 10000000: 9181.09 ms

$ java ConcurrentPrime 10000000 8
prime count = 664579
concurrent 10000000 with 8: 12989.82 ms
prime count = 664579
sequencial 10000000: 9128.46 ms

$ java ConcurrentPrime 10000000 3
prime count = 664579
concurrent 10000000 with 3: 13655.63 ms
prime count = 664579
sequencial 10000000: 9190.02 ms

やっぱり並列処理はむずかしい。Java では大人しく java.util.concurrent を使うのがよさげかなぁ。

追記:
一度投稿したあと、よくよく見てみると色々と間違ってたので修正した。