前書き
いつもPHPカテゴリへの投稿ばかりで多分Javaは今回が初めてです。つまりJavaに関しては ド素人 です。加えて数学関連はわりと苦手なダメプログラマです。先日、大学で展開科目として取っている知能処理アルゴリズム論の講義で
問題空間(グラフ、コスト、ヒューリスティック値)を自分で設定し、 A*アルゴリズム で解け
という課題が出たのですが、どうせなら無意味なグラフを解くよりも具体的な問題が解きたいと思ったので講義中でも紹介されていた 8パズル を解くコードを書いてみることにしました。まだ理解に乏しい面があるので、通常より計算量が多くなってしまう実装になっているかもしれません。不適切なコーディングがありましたらコメントで生暖かく指摘お願いします(汗
内容
123456789x
の文字を全て1回ずつ使用し、第1引数にスタート状態・第2引数にゴール状態を指定して実行して、各ステップを見やすく出力します。
ソースコード
Node.java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Node implements Comparable<Node> {
private final String nodeValue;
private final String goalValue;
private final int heuristicAmount;
private int passedAmount = 0;
private Node parent;
public Node(String nodeValue, String goalValue) {
this.nodeValue = nodeValue;
this.goalValue = goalValue;
this.heuristicAmount = calculateHeuristicAmount();
}
@Override
public int compareTo(Node node) {
return this.getPriority() - node.getPriority();
}
@Override
public String toString() {
String table = "-------\n";
for (int i = 0; i < 3; ++i) {
table += "|";
for (int j = i * 3; j < i * 3 + 3; ++j) {
table += nodeValue.charAt(j) + "|";
}
table += "\n-------\n";
}
return table;
}
public String toCompleteString() {
List<String> states = new ArrayList<>();
Node node = this;
do {
states.add(node.toString());
} while ((node = node.parent) != null);
Collections.reverse(states);
String str = "0.\n" + states.get(0);
for (int i = 1; i < states.size(); ++i) {
str += "\n" + i + ".\n" + states.get(i);
}
return str;
}
public int getHeuristicAmount() {
return heuristicAmount;
}
public int getPriority() {
return passedAmount + heuristicAmount;
}
public String getNodeValue() {
return nodeValue;
}
public List<Node> openChildNodes() {
int from = nodeValue.indexOf("x");
int[] offsets = {from - 3, from + 3, from - 1, from + 1};
List<Node> childNodes = new ArrayList<>();
for (int to : offsets) {
if (
to >= 0 &&
to <= 8 &&
!(from == 2 && to == 3) &&
!(from == 3 && to == 2) &&
!(from == 5 && to == 6) &&
!(from == 6 && to == 5)
) {
Node node = new Node(swapOffset(from, to), goalValue);
node.parent = this;
node.passedAmount = (parent != null ? parent.passedAmount : 0) + 2;
childNodes.add(node);
}
}
return childNodes;
}
private String swapOffset(int from, int to) {
char[] chars = nodeValue.toCharArray();
char tmp = chars[from];
chars[from] = chars[to];
chars[to] = tmp;
return String.valueOf(chars);
}
private int calculateHeuristicAmount() {
int sum = 0;
for (int i = 0; i < 9; ++i) {
int j = goalValue.indexOf(String.valueOf(nodeValue.charAt(i)));
sum += (int)(Math.abs(i - j) / 3) + Math.abs(i - j) % 3;
}
return sum;
}
}
EightPuzzle.java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
public class EightPuzzle {
private final Queue<Node> nodes = new PriorityQueue<>();
private final Map<String, Node> visited = new HashMap<>();
public static void main(String[] args) {
try {
new EightPuzzle(args).solve();
} catch (IllegalArgumentException e) {
System.out.println(e.getMessage());
}
}
private EightPuzzle(String[] args) {
validateInput(args);
nodes.add(new Node(args[0], args[1]));
}
private void solve() {
Node node;
int counter = 0;
while ((node = nodes.poll()) != null) {
if (node.getHeuristicAmount() == 0) {
System.out.println(String.format("解が見つかりました。(展開回数: %d)", counter));
System.out.println();
System.out.println(node.toCompleteString());
break;
}
++counter;
for (Node opened : node.openChildNodes()) {
if (visited.containsKey(opened.getNodeValue())) {
if (opened.getPriority() < visited.get(opened.getNodeValue()).getPriority()) {
visited.remove(opened.getNodeValue());
visited.put(opened.getNodeValue(), opened);
nodes.add(opened);
}
} else {
visited.put(opened.getNodeValue(), opened);
nodes.add(opened);
}
}
}
if (node == null) {
System.out.println("このパズルは解けません。");
}
}
private void validateInput(String[] args) {
if (args.length != 2) {
throw new IllegalArgumentException("引数の数は2個にしてください。");
}
char[] srcChars, dstChars, validChars;
srcChars = args[0].toCharArray();
dstChars = args[1].toCharArray();
validChars = "12345678x".toCharArray();
Arrays.sort(srcChars);
Arrays.sort(dstChars);
Arrays.sort(validChars);
if (
!Arrays.equals(srcChars, dstChars) ||
!Arrays.equals(srcChars, validChars)
) {
throw new IllegalArgumentException("各引数は「12345678x」の文字を漏れなく1回ずつ使用してください。");
}
}
}
実行例
コマンドライン
java EightPuzzle 5862734x1 87654321x
実行結果
解が見つかりました。(展開回数: 16)
0.
-------
|5|8|6|
-------
|2|7|3|
-------
|4|x|1|
-------
1.
-------
|5|8|6|
-------
|2|7|3|
-------
|x|4|1|
-------
2.
-------
|5|8|6|
-------
|x|7|3|
-------
|2|4|1|
-------
3.
-------
|x|8|6|
-------
|5|7|3|
-------
|2|4|1|
-------
4.
-------
|8|x|6|
-------
|5|7|3|
-------
|2|4|1|
-------
5.
-------
|8|7|6|
-------
|5|x|3|
-------
|2|4|1|
-------
6.
-------
|8|7|6|
-------
|5|4|3|
-------
|2|x|1|
-------
7.
-------
|8|7|6|
-------
|5|4|3|
-------
|2|1|x|
-------