技術ブログ

Developers's blog

【Python】TensorFlow / SegmentationチュートリアルでU-netを学習させる

2019.12.05 田村 和樹
TensorFlow ニューラルネットワーク 機械学習
【Python】TensorFlow / SegmentationチュートリアルでU-netを学習させる

はじめに

前回の「画像セグメンテーションのためのU-net概要紹介」では画像のクラス分類のタスクを、画像のSegmentationのタスクにどう発展させるかを解説し、SegmentationのネットワークであるU-netの理論ついて簡単に解説しました。
今回はTensorFlowのSegmentationのチュートリアルを行いながら、実際にU-netを学習させてみたいと思います。
尚、本記事ではTensorflowの詳しい解説は行いません。
参考 : https://www.tensorflow.org/tutorials/images/segmentation

Segmentationとは

ある物体が画像内に含まれている時、画像のどこにあるのかを推定するタスクのことです。
言い換えると「画像のピクセルがそれぞれ何かを推定する」タスクのことです。
今回はOxford-IIIT Pet Datasetというデータセットを用いて学習を行い、ピクセルごとに以下のようなクラス分けを行います。

  • Class 0 : 動物のピクセル
  • Class 1 : 動物とその他の境界線のピクセル
  • Class 2 : その他のピクセル


目標は以下のような出力を得ることです。(左:入力画像 右:出力画像)
Alt text


ライブラリのインポート

import tensorflow as tf
import sys
from IPython.display import display
from IPython.display import HTML
from PIL import Image
# sys.modules['Image'] = Image 
from __future__ import absolute_import , division, print_function, unicode_literals
from tensorflow_examples.models.pix2pix import pix2pix

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from IPython.display import clear_output
import matplotlib.pyplot as plt


データの読み込み&可視化

dataset, info = tfds.load('oxford_iiit_pet:3.0.0', with_info=True)

def normalize(input_image,input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255
    input_mask -= 1
    return input_image,input_mask

@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'],(128,128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'],(128,128))

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image,input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

for image, mask in train.take(88):
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

Alt text


U-netの構築

U-netのencorder部分には今回のデータセットとは別のデータセットを学習したのMovileNetを用います。
MovileNetの重みは学習で更新しないように更新しておきます。
このようにすることU-netの精度を上げることができます。(転移学習)
decorder部分には未学習のpix2pixを用います。
pix2pixの重みは学習により更新されます。

OUTPUT_CHANNELS = 3

# encorder部分には学習済みのMovileNet
base_model = tf.keras.applications.MobileNetV2(input_shape=[128,128,3],include_top=False)
layer_names = [
    'block_1_expand_relu',
    'block_3_expand_relu',
    'block_6_expand_relu',
    'block_13_expand_relu',
    'block_16_project'
]
layers = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
# MovileNetの重みは固定
down_stack.trainable = False

# decorder部分にはpix2pixを用いる
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels):

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same', activation='softmax')  #64x64 -> 128x128

    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    x = inputs

    # Downsampling through the model
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


U-netの学習

実際にU-netが学習していく様子を眺めてみましょう。
epochが進むにつれ、正しく予測できているのがわかります。

class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch,logs=None):
        clear_output(wait=True)
        ims.append([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])
        print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
            ims.append([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])
        ims.append([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])

ims = []
EPOCHS = 100
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

import matplotlib.animation as animation

def make_animation(ims):
    %matplotlib nbagg
    fig, (ax0, ax1, ax2) = plt.subplots(1,3,figsize=(14.0, 8.0))
    ax0.axis('off')
    ax1.axis('off')
    ax2.axis('off')
    ax0.set_title('Input Image')
    ax1.set_title('True Mask')
    ax2.set_title('Predicted Mask')

    ims2 = []
    for epoch,im in enumerate(ims):   

        im0, = [ax0.imshow(tf.keras.preprocessing.image.array_to_img(im[0]))]
        im1, = [ax1.imshow(tf.keras.preprocessing.image.array_to_img(im[1]))]
        im2, = [ax2.imshow(tf.keras.preprocessing.image.array_to_img(im[2]))]
        ims2.append([im0,im1,im2])
    ani = animation.ArtistAnimation(fig, ims2, interval=50, repeat_delay=1000)
    return ani

学習の様子

学習の様子を見てみます。epochが進むごとに精度が増していることがわかります。

Alt text

ani1 = make_animation(ims)
HTML(ani1.to_jshtml())


おまけ

U-netにはskip-conectionという手法が使われています。
encorder部分で畳み込みをして失ってしまった画像内の位置情報を保持する役割を持ちます。
U-netにskip-conectionが無い場合も比較してみましょう。

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model_no_sc(output_channels):

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same', activation='softmax')  #64x64 -> 128x128

    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    x = inputs
    x = down_stack(x)
    x = x[-1]
    for up in up_stack:
        x = up(x)
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

ims = []
EPOCHS = 100
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model = unet_model_no_sc(OUTPUT_CHANNELS)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])


学習の様子

こちらも学習の経過を眺めてみます。
skip-conectionがある場合と比べ、学習が遅いばかりか精度が悪いことがわかります。

Alt text


Twitter・Facebookで定期的に情報発信しています!

関連記事

【物体検出】SSD(Single Shot MultiBox Detector)の解説

概要 先日の勉強会にてインターン生の1人が物体検出について発表してくれました。これまで物体検出は学習済みのモデルを使うことが多く、仕組みを知る機会がなかったのでとても良い機会になりました。今回の記事では発表してくれた内容をシェアしていきたいと思います。 あくまで物体検出の入門ということで理論の深堀りや実装までは扱いませんが悪しからず。 物体検出とは ディープラーニングによる画像タスクといえば画像の分類タスクがよく挙げられます。例としては以下の犬の画像から犬

記事詳細
【物体検出】SSD(Single Shot MultiBox Detector)の解説
ニューラルネットワーク 物体検知
【論文】

概要 小説を丸ごと理解できるAIとしてReformerモデルが発表され話題になっています。今回はこのReforerモデルが発表された論文の解説を行います。 自然言語や音楽、動画などのSequentialデータを理解するには広範囲における文脈の依存関係を理解する必要があり困難なタスクです。"Attention is all you need"の論文で紹介されたTransformerモデルは広くこれらの分野で用いられ、優秀な結果を出しています。 例えば機械翻訳

記事詳細
【論文】"Reformer: The Efficient Transformer"の解説
ニューラルネットワーク 論文解説
【論文】

機械学習では、訓練データとテストデータの違いによって、一部のテストデータに対する精度が上がらないことがあります。 例えば、水辺の鳥と野原の鳥を分類するCUB(Caltech-UCSD Birds-200-2011)データセットに対する画像認識の問題が挙げられます。意図的にではありますが訓練データを、 水辺の鳥が写っている画像は背景が水辺のものが90%、野原のものが10% 野原の鳥が写っている画像は背景が水辺のものが10%、野原のものが90% となるように

記事詳細
【論文】"Distributionally Robust Neural Networks"の解説
ニューラルネットワーク 機械学習 論文解説
強力な物体検出M2Detで動画の判別する(google colaboratory)

はじめに この記事では物体検出に興味がある初学者向けに、最新技術をデモンストレーションを通して体感的に知ってもらうことを目的としています。今回紹介するのはAAAI19というカンファレンスにて精度と速度を高水準で叩き出した「M2Det」です。one-stage手法の中では最強モデル候補の一つとなっており、以下の図を見ても分かるようにYOLO,SSD,Refine-Net等と比較しても同程度の速度を保ちつつ、精度が上がっていることがわかります。 ※https:

記事詳細
強力な物体検出M2Detで動画の判別する(google colaboratory)
ニューラルネットワーク 機械学習 画像認識

お問い合わせはこちらから