TensorFlowのチュートリアルMNIST For ML Beginnersを試してみる。

プログラミングの世界では、最初に「Hello World」と表示される非常に小さいプログラムを書くことが伝統となっている。
機械学習では、この「Hello World」の代わりに、MNISTを実行するのが伝統のようだ。

MNISTは、手書きの数字の画像のデータセットである。

ここでの目的は、手書きの数字の画像がどのような数字かを予測するモデルを構築することである。

input_data.pyの取得

最初に、tensorflow/input_data.py at r0.8 · tensorflow/tensorflow · GitHubから、表示されているコードをコピーして、適当なテキストエディタにペーストして、input_data.pyとして保存しておく。

このinput_data.pyは、データをウェブ上から取得したり、データを扱いやすいよう加工してくれる。

MNISTの手書き画像データのダウンロード

続いて、MNISTの手書き画像データをダウンロードをする。
上記にも記載したが、ここでは、先に手動でダウンロードしておくことにする。

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burgesから、次のファイルをダウンロードする。

train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

チュートリアルの実行

適当にワークスペースを作成し、次のようなファイル、ディレクトリ構成にしておく。
下のコードをtutorial.pyとして保存する。


workspace
    |-MNIST_data
    |    |-train-images-idx3-ubyte.gz
    |    |-train-labels-idx1-ubyte.gz
    |    |-t10k-images-idx3-ubyte.gz
    |    |-t10k-labels-idx1-ubyte.gz
    |-input_data.py
    |-tutorial.py

MNIST_dataディレクトリ以下に、MNISTの手書き画像データがない場合は、input_data.read_data_sets関数内で、ダウンロードされる。


import tensorflow as tf
import input_data

# MNISTデータを読み込み
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 画像データ
x = tf.placeholder("float", [None, 784])

# モデルの重み
W = tf.Variable(tf.zeros([784, 10]))

# モデルのバイアス
b = tf.Variable(tf.zeros([10]))

# トレーニングデータxとモデルの重みWを乗算した後、モデルのバイアスbを足し、
# ソフトマックス回帰(ソフトマックス関数)を適用
y = tf.nn.softmax(tf.matmul(x, W) + b)

# 正解データ
y_ = tf.placeholder("float", [None, 10])

# 損失関数をクロスエントロピーとする
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

# 学習係数を0.01として、勾配降下アルゴリズムを使用して、
# クロスエントロピーを最小化する
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 変数の初期化
init = tf.initialize_all_variables()

# セッションの作成
sess = tf.Session()

# セッションの開始および初期化の実行
sess.run(init)

# トレーニングの開始
for i in range(1000):
    # トレーニングデータからランダムに100個抽出する
    batch_xs, batch_ys = mnist.train.next_batch(100)

    # 確率的勾配降下によりクロスエントロピーを最小化するよう重みを更新
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# 予測値と正解値を比較して、bool値(true or false)にする
# tf.argmax(y, 1)は、予測値の各行で、最大値となるインデックスを一つ返す
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

# bool値を0 or 1に変換して平均値をとる -> 正解率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# テストデータを与えて、テストデータの正解率の表示
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

このtutorial.pyを実行すると、次のように表示された。
これは、トレーニングデータから構築された予測モデルに、テストデータを当てはめたところ、およそ91%の正解率であるということである。
もちろん、その時その時で多少変動する。


$ python3 tutorial.py
・・・
0.9113

関連する記事

  • WindowsにRStudioをインストールする手順WindowsにRStudioをインストールする手順 WindowsにRStudioをインストールする手順についてお伝えいたします。 ファイルのダウンロード RStudioのインストールファイルをダウンロードするために、次のサイトに移動します。 RStudio – Open source and enterprise-ready professional software for […]
  • Python CaboChaを用いて係り受け構造を抽出する方法Python CaboChaを用いて係り受け構造を抽出する方法 Pythonと日本語係り受け解析器であるCaboChaを用いて係る語と受ける語のペアを抽出する方法をご紹介する。 環境:Ubuntu14.04 Pythonツールのインストール PythonからCaboChaを扱うために、CaboChaに付属しているPythonのsetup.pyをインストールする。 これはPython2系専用であることに注意する。 caboch […]
  • MINTELのGNPDから出力されたCSVの文字化けの対処法 ミンテルGNPD(世界新商品情報データベース)から出力されたCSVが文字化けしている場合の対処法を備忘録として残しておきます。 Windowsパソコンにダウンロードかつ英文のみ ここでの対応は、「café」のようなアキュート・アクセントなどが文字化けしている場合の対応になります。 ubuntu上で文字コードをUTF-8に変換するにはiconvコマンドを用いて、次 […]
  • R knitrで特定ページを横向きにしてPDF出力する方法R knitrで特定ページを横向きにしてPDF出力する方法 knitrでレポートをPDF出力する際に、ある特定ページだけ横向きにする方法をお伝えする。 まずは、事前準備として本体となるファイルの同一ディレクトリに「header.tex」として、次の内容を書き込んで保存しておく。 これは、このファイルが読み込まれた時に、「lscape」パッケージを読み込んで、 「\blandscape」「\enlandscape」と記述してある場 […]
  • 顧客満足度調査から重要な改善点を導く方法顧客満足度調査から重要な改善点を導く方法 顧客満足度調査とは、顧客に対して提供している商品やサービスに対して、顧客がどれだけ満足しているか、または不満を持っているか、満足している点はどこか、不満を持っている点はどこかなどをアンケートなどにより調査することだ。 顧客満足度調査の結果は、各項目ごとに平均値を出したり、棒グラフやレーダーチャートにすることが多い。 もし、あなたが顧客満足度調査の結果を見て、「ある […]
TensorFlow チュートリアルMNIST For Beginnersを試してみる

TensorFlow チュートリアルMNIST For Beginnersを試してみる」への2件のフィードバック

コメントは受け付けていません。