コケッココケッコ

コケコッコー

お手持ちの画像で画像認識 by using TensorFlow (keras)

概要

手持ちの画像を使って画像認識したいニーズがあったので、Kerasを用いてDeep Learningすることにした。

とはいえ今回はDeep Learningとはいってもすべて全結合、隠れ層の活性化関数にsigmoid関数、出力層にsoftmax関数を用いた簡単なネットワークを作ってみた.

対象

わんこの画像を分類する.

tmp/train, tmp/test にそれぞれ学習画像、テスト画像が入っている。

各画像は以下のように犬 ${犬種}_{$id}.jpg といったファイル名で格納されている.

犬 チワワ1.jpg 犬 柴犬2.jpg

学習データは各犬種につき10枚程度。

ソース

# deep learningしてみるや〜つ
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
import os

from keras.preprocessing.image import img_to_array, load_img
from tensorflow.contrib.keras.python.keras.utils import to_categorical

from src.util.fileutil import FileUtil


class ModelFinder:
    @staticmethod
    def getSimpleModel(n_in, n_hidden, num_hidden, n_out):
        """
        全結合のモデル
        :param n_in: 入力の次元
        :param n_hidden: 隠れ層の次元数
        :param num_hidden: 隠れ層の数
        :param n_out: カテゴリ数
        :return:
        """
        model = Sequential()

        # 入力層 - 隠れ層
        model.add(Dense(n_hidden, input_dim=n_in, activation='sigmoid'))

        # 隠れ層 - 隠れ層
        for i in range(num_hidden-1):
            model.add(Dense(n_hidden, activation='sigmoid'))

        # 隠れ層 - 出力層
        model.add(Dense(n_out, activation='softmax'))

        return model

def loadImg(dir, pic_row, pic_col):
    X_train = []
    Y_cat = []
    for filepath in FileUtil.list_file(dir):
        # 画像を配列にぶち込む.
        filename = os.path.basename(filepath)
        catname = filename.replace('犬 ', '')[:filename.find('_')-len('犬 ')]
        X_train.append(img_to_array(load_img(filepath, target_size=(pic_row, pic_col))).reshape(-1))
        Y_cat.append(catname)

        print("%s %s" % (filepath, catname))
    return (np.array(X_train), Y_cat)

def createCatDic(cats):
    ret = {}
    for cat in cats:
        if cat not in ret:
            ret[cat] = len(ret)
    return ret

if __name__ == '__main__':
    train_dir = '../tmp/train'
    test_dir = '../tmp/test'

    # 隠れ層の次元数
    n_hidden = 128

    # 隠れ層の数 (>=1)
    num_hidden = 1

    # pic_row*pic_col サイズの画像を入力とする
    pic_row = 128
    pic_col = 128
    n_in = pic_row * pic_col * 3 # 3: color. 白黒画像で学習させたいならこいつを1にする

    X_train, Y_train_cat = loadImg(train_dir, pic_row, pic_col)
    cat_dic = createCatDic(Y_train_cat)
    Y_train = to_categorical([cat_dic[cat] for cat in Y_train_cat], len(cat_dic))
    X_test , Y_test_cat = loadImg(test_dir, pic_row, pic_col)
    Y_test = to_categorical([cat_dic[cat] for cat in Y_test_cat], len(cat_dic))

    # カテゴリ数
    n_out = len(cat_dic)

    model = ModelFinder.getSimpleModel(n_in, n_hidden, num_hidden, n_out)
    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=0.01),
                  metrics=['accuracy'])

    # モデル学習
    epochs = 10000
    batch_size = min(100, len(X_train))

    model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size)

    # 予測精度の評価
    loss_and_metrics = model.evaluate(X_test, Y_test)
    print(loss_and_metrics) # 2番目にaccuracyが入っている.

結果

27/27 [==============================] - 0s - loss: 0.5093 - acc: 0.7778
Epoch 9999/10000
27/27 [==============================] - 0s - loss: 0.5093 - acc: 0.7778
Epoch 10000/10000
27/27 [==============================] - 0s - loss: 0.5093 - acc: 0.7778
14/14 [==============================] - 0s
[1.4256515502929688, 0.3571428656578064]

学習データだとaccuracy: 81.15% テストデータだと35.71%と辛い感じ。

隠れ層の次元を増やしていったときの精度を見る。 隠れ層の数n_dim = 2

27/27 [==============================] - 0s - loss: 0.3749 - acc: 0.8519
Epoch 10000/10000
27/27 [==============================] - 0s - loss: 0.3749 - acc: 0.8519
14/14 [==============================] - 0s
[1.4254211187362671, 0.5]

隠れ層を2つにしたら50%に上昇。

Epoch 9999/10000
27/27 [==============================] - 0s - loss: 0.0981 - acc: 1.0000
Epoch 10000/10000
27/27 [==============================] - 0s - loss: 0.0981 - acc: 1.0000
14/14 [==============================] - 0s
[2.0662350654602051, 0.42857146263122559]

隠れ層を3つにしたら過学習したっぽいwwつらい

Epoch 9998/10000
27/27 [==============================] - 0s - loss: 0.6240 - acc: 0.7407
Epoch 9999/10000
27/27 [==============================] - 0s - loss: 0.6239 - acc: 0.7407
Epoch 10000/10000
27/27 [==============================] - 0s - loss: 0.6238 - acc: 0.7407
14/14 [==============================] - 0s
[1.3305658102035522, 0.21428573131561279]

隠れ層を4つにしたらさらに精度が落ちた。

考察

学習データの数が足りなさすぎると思われるので、回転させたりして画像を水増ししてみよう。 各画像(X_train, X_test)は255で割って正規化したほうがいいらしい。 勾配消失問題への対処としてReLUを使ってみるか。 画像認識の十八番CNNも試してみるか〜。

正直テストデータで精度90%ぐらい出てもらわないと困るぞぅ。

参考書籍

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~

詳解 ディープラーニング ~TensorFlow・Kerasによる時系列データ処理~

AOJ 0121 Seven Puzzle

概要

7 パズル | Aizu Online Judge

解法

7パズルは解法が7!通りしかないので事前に計算しておく。幅優先探索で全列挙すればOK。 ただし、7パズルの右上マス、右下マスには注意(※右上マスには右のマスという概念がない)

ソース

import java.io.IOException;
import java.util.*;

class Main {
    public static void main(String[] args) throws IOException {
        Scanner in = new Scanner(System.in);
        Map<String, Integer> ans = new HashMap<>();
        create_ans(ans);
        while(in.hasNext()) {
            String[] map = in.nextLine().split(" ");
            System.out.println(ans.get(""+new Node(0, 0, map).getStrId()));
        }
    }

    public static void create_ans(Map<String, Integer> ans) {
        String[] map = {"0","1","2","3","4","5","6","7"};
        Set<Integer> s = new HashSet<>();
        Queue<Node> q = new PriorityQueue<>();
        int z_pos = 0;
        for(int i=0;i<8;i++) {
            if(map[i].equals("0")) {
                z_pos = i;break;
            }
        }
        q.offer(new Node(z_pos, 0, map));
        while(!q.isEmpty()) {
            Node node = q.poll();
            int id = node.getId();
            s.add(id);
            ans.put(node.getStrId(), node.num);

            // 上下左右を見て、入れ替えられるなら入れ替える
            int di[] = {-4, +4, -1, +1};
            for(int i=0;i<4;i++) {
                int zero = node.zero_pos + di[i];
                // ありえない
                if(zero < 0 || zero > 7) continue;
                if(node.zero_pos == 3 && i == 3) continue;
                if(node.zero_pos == 4 && i == 2) continue;

                // 入れ替える
                Node newNode = new Node(zero, node.num+1, node.map);
                swap(newNode.map, node.zero_pos, zero);
                if(s.contains(newNode.getId())) continue;
                newNode.addSteps(node.getSteps());
                newNode.addStep(newNode);

                q.offer(newNode);
                s.add(newNode.getId());
            }
        }
    }

    public static void swap(String[] map, int p1, int p2) {
        String buf = map[p1];
        map[p1] = map[p2];
        map[p2] = buf;
    }
}

class Node implements  Comparable<Node>{
    int zero_pos = 0;
    int num;
    String[] map;
    List<Node> steps = new ArrayList<>();
    Node(int zero_pos, int num, String[] map) {
        this.zero_pos = zero_pos;
        this.num = num;
        this.map = map.clone();
    }

    public void addStep(Node node) {
        this.steps.add(node);
    }

    public void addSteps(List<Node> steps) {
        this.steps = new ArrayList<Node>(steps);
    }

    public List<Node> getSteps() {
        return this.steps;
    }

    public int getId() {
        int id = 0;
        for(int i=0;i<map.length;i++) {
            id = id * 10 + Integer.parseInt(map[i]);
        }
        return id;
    }

    public String getStrId() {
        String id = "";
        for(int i=0;i<map.length;i++) {
            id = id + map[i];
        }
        return id;
    }

    @Override
    public int compareTo(Node o) {
        return this.num - o.num;
    }
}

AOJ 0558: Cheese

概要

チーズ | Aizu Online Judge

解法

Sから1、1から2、2から3…の最短距離をそれぞれ求めて和を取る。 最短距離の計算は幅優先探索でOK

コード

import java.io.IOException;
import java.util.*;

class Main {
    public static void main(String[] args) throws IOException {
        Scanner in = new Scanner(System.in);
        int H, W, N;
        H = in.nextInt();
        W = in.nextInt();
        N = in.nextInt();

        String[] map = new String[H+3];
        String str = "";
        for(int i=0;i<W+2;i++) {
            str += "X";
        }
        map[0] = str;

        in.nextLine();
        for(int i=0;i<H;i++) {
            String s = in.nextLine();
            map[i+1] = "X" + s + "X";
        }
        map[H+1] = str;

        solve(map, H, W, N);
    }

    public static void solve(String[] ma, int H, int W, int N) {
        int ans = 0;
        int posx[] = new int[N+1];
        int posy[] = new int[N+1];

        // まずSの場所を探す
        for(int i=1;i<=H;i++) {
            int pos = ma[i].indexOf("S");
            if(pos >= 0) {
                posx[0] = i;
                posy[0] = pos;
                ma[i].replace("S", "0");
                break;
            }
        }
        for(int i=1;i<=H;i++) {
            for(int j=1;j<=N;j++) {
                int pos = ma[i].indexOf(""+j);
                if(pos >= 0) {
                    posx[j] = i;
                    posy[j] = pos;
                }
            }
        }

        for(int i=1;i<=N;i++) {
            ans += bfs(ma, posx[i-1], posy[i-1], H, W, ""+i);
        }
        System.out.println(ans);
    }

    public static int bfs(String[]ma, int sh, int sw, int H, int W, String target) {
        // 幅優先探索でtargetという文字列までの最短距離を計算
        boolean[][] visited = new boolean[H+2][W+2];
        Queue<P> q = new PriorityQueue<>();
        q.offer(new P(sh, sw, 0)); // スタート地点を入れる

        while(!q.isEmpty()) {
            P p = q.poll();
            visited[p.sh][p.sw] = true;

            // 数字が見つかった
            if((""+ma[p.sh].charAt(p.sw)).equals(target)) {
                return p.dist;
            }

            int dh[] = {0, 0, +1, -1};
            int dw[] = {1,-1,  0,  0};
            for(int i=0;i<4;i++) {
                int h = p.sh + dh[i];
                int w = p.sw + dw[i];
                if((""+ma[h].charAt(w)).equals("X") || visited[h][w]) continue;
                visited[h][w] = true;
                q.offer(new P(h, w, p.dist + 1));
            }
        }

        assert (false);
        return -1;
    }
}

class P  implements Comparable<P> {
    int sh;
    int sw;
    int dist;
    P(int sh, int sw, int dist) {
        this.sh = sh;
        this.sw = sw;
        this.dist = dist;
    }

    @Override
    public int compareTo(P p) {
        return this.dist - p.dist;
    }
}

AOJ 0118 Property Distribution

概要

財産分配 | Aizu Online Judge

解法

パット見DFSしたくなるが、Runtime Error.

蟻本の練習問題には深さ優先探索で解けと書いてあったが、再帰だと最悪W*H回のDFSになりRuntime ErrorとなりACされないためstackを使った。

ソース

import java.io.IOException;
import java.util.*;

class Main {
    public static void main(String[] args) throws IOException {
        Scanner in = new Scanner(System.in);

        while (true) {
            int H = in.nextInt();
            int W = in.nextInt();
            if (H == 0 && W == 0) return;

            in.nextLine();

            String[] map = new String[H + 200];
            String str = "";
            for (int i = 0; i < W + 2; i++) {
                str += ".";
            }
            map[0] = str;
            for (int i = 1; i <= H; i++) {
                String s = in.nextLine();
                map[i] = "." + s + ".";
            }
            map[H + 1] = str;

            solve(H, W, map);
        }
    }

    public static void print(String[] map, int H, int W) {
        for (int i = 0; i < H + 2; i++) {
            for (int j = 0; j < W + 2; j++) {
                System.out.print(map[i].charAt(j));
            }
            System.out.println("");
        }
    }

    public static void solve(int H, int W, String[] map) {
//     print(map, H, W);
        boolean visited[][] = new boolean[H + 200][W + 200];
        int ans = 0;
        for (int i = 1; i <= H; i++) {
            for (int j = 1; j <= W; j++) {
                if (!visited[i][j]) {
                    dfs(visited, map, "" + map[i].charAt(j), i, j);
                    ans++;
                }
            }
        }
        System.out.println(ans);
    }

    public static void dfs(boolean[][] visited, String[] map, String key, int h, int w) {
        Stack<P> sta = new Stack<>();
        sta.push(new P(h, w));
        while (sta.isEmpty() == false) {
            P p = sta.pop();
            h = p.h;
            w = p.w;
            visited[h][w] = true;
            int di[] = {0, 0, -1, +1};
            int dj[] = {+1, -1, 0, 0};
            for (int i = 0; i < 4; i++) {
                int ni = h + di[i];
                int nj = w + dj[i];
                if (visited[ni][nj] == false && ("" + map[ni].charAt(nj)).equals(key)) {
                    sta.push(new P(ni, nj));
//                 dfs(visited, map, key, ni, nj);
                }
            }
        }
    }
}

class P {
    int h;
    int w;

    P(int h, int w) {
        this.h = h;
        this.w = w;
    }
}

AOJ 0042

  • 概要

泥棒 | Aizu Online Judge

  • 解法

DP。ナップザック問題そのまま。

dp[i+1][j]:=i個まで選んだときの、重さがj以下における価値の総和の最大値とすると、

  • dp[0][j] = 0
  • dp[i][j] = max(dp[i-1][j-w[i]] + v[i], dp[i-1][j]) ただし、j-w[i] >= 0

  • コード

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.*;

class Main {
    public static void main(String[] args) throws IOException {
        Scanner in = new Scanner(System.in);
        int cnt = 1;
        while(true) {
            int W = in.nextInt();
            if(W == 0) return;
            int N = in.nextInt();
            int v[] = new int[N+1];
            int w[] = new int[N+1];
            in.nextLine();
            for(int i=1;i<=N;i++) {
                String[] strs = in.nextLine().split(",");
                v[i] = Integer.parseInt(strs[0]);
                w[i] = Integer.parseInt(strs[1]);
            }
            System.out.println("Case "+cnt+":");
            solve(W, N, v, w);
            cnt++;
        }
    }

    public static void solve(int W, int N, int[] v, int[] w) {
        int[][] dp = new int[N+2][W+2];
        for(int i=0;i<N+2;i++) dp[0][i] = 0;

        for(int i=1;i<=N;i++) {
            for(int j=1;j<=W;j++) {
                dp[i][j] = dp[i-1][j];
                if(j-w[i] >= 0) {
                    dp[i][j] = Math.max(dp[i - 1][j - w[i]] + v[i], dp[i - 1][j]);
                }
            }
        }
        int ans_V = 0;
        int ans_W = Integer.MAX_VALUE / 2;
        for(int i=0;i<=W;i++) {
            if(ans_V < dp[N][i]) {
                ans_V = dp[N][i];
                ans_W = i;
            }
        }
        System.out.println(ans_V);
        System.out.println(ans_W);
    }
}

AOJ 0035 Is it Convex?

  • 問題

凹みの検知 | Aizu Online Judge

2次元座標上に四角形が与えられる。凹みがなければ YES、凹みがあれば NOと出力せよ

  • 解法

外積を取って符号を見る

vec(A, B), vec(A, C)の外積を取る→符号を覚えておく

vec(B, C), vec(B, D)の外積を取り、符号が変わっていたら凹みあり、符号が同じなら以下vec(D,A),vec(D,B)の外積を取るまで続ける。

符号が最後まで同じなら凹みなし。

  • プログラム
import java.io.IOException;
import java.util.*;

class Point {
    double x;
    double y;
    Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }
}

class Vec {
    Point a;
    Point b;
    Vec(Point a, Point b) {
        this.a = a;
        this.b = b;
    }
}

class MathUtil {
    /**
    * 外積を求める。
    *     Q
    *   /
    *   /
    *  C -------- P
    *
    * @param p
    * @param c
    * @param q
    * @return >0: 左回り、 <0:右回り、 0:Q,C,Pは一直線上にある
     */
    public static double outProd(Point c, Point p, Point q) {
        return (p.getX() - c.getX()) * (q.getY() - c.getY()) - (p.getY() - c.getY()) * (q.getX() - c.getX());
    }

    public static boolean isZero(long EPS, double val) {
        return Math.abs(val) <= EPS;
    }
}

/**
 */
class Main {
    final static int N = 10;

    public static void main(String[] args) throws IOException {
        Scanner in = new Scanner(System.in);
        while (in.hasNext()) {
            String[] strArr = in.nextLine().split(",");
            double[] intArr = new double[strArr.length];
            for(int i=0;i<intArr.length;i++) {
                intArr[i] = Double.parseDouble(strArr[i]);
            }
            List<Point> points = new ArrayList<>();
            for(int i=0;i<intArr.length;i+=2) {
                double a = intArr[i];
                double b = intArr[i+1];
                points.add(new Point(a, b));
            }
            solve(points);
        }
    }

    public static void solve(List<Point> pts) {
        double prev = 0;
        int n =pts.size();
        for(int i=0;i<pts.size();i++) {
            double ret = MathUtil.outProd(pts.get(i%n), pts.get((i+1)%n), pts.get((i+2)%n));
            if(MathUtil.isZero((long)(1e-9), ret)) {
                continue;
            }
            if(ret * prev < 0) {
                System.out.println("NO");
                return;
            }
            prev = ret;
        }
        System.out.println("YES");
    }
}