Gradio: 簡易アプリ作成 (画像分類編)

便利ツール

 以前の記事では「Gradio」というpythonで簡単にアプリを作れるツールについて紹介していきました! htmlやJavaScriptなどの知識も要らず作れて便利なのですが,以前の記事では "Hello World"を表示させるところまでの紹介だったので,今回の記事では実際に計算機アプリを作りながら,流れをまとめていきます!!

Gradioとは?といったところに気になる方はこちらも合わせて読んでみてください!

スポンサーリンク

画像分類アプリ作り方

 今回,作る画像分類アプリはこのような形です!画像をアップロードすると,その画像に写っているものを分類してくれるアプリです.

 入力値として 画像をアップロードして送信ボタンを押すと右側のoutputに分類結果が表示されます!

アプリ作成の流れ

アプリ作成の大まかな流れは以下の通りです! google colabで実行するだけでもサーバーレスでアプリを作れてしまうのはとても便利なので,興味ある方は是非試してみてください!


流れ

0. 前準備) pip install gradio でインストール

1. 画像分類モデルの設定

2. Interface内で実行する関数(predict)を定義

3. Interface を定義

4. launch してアプリを構築

1. 画像分類モデルの設定

 画像分類を行うにあたって,モデルを作る必要があるのでその設定を行なっていきます.この記事では,PyTorch Hub からResNet18をダウンロードして使っていきます!(自分で作ったモデルを設定することも可能ですよ)

import torch

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
PyTorch
An open source machine learning framework that accelerates the path from research prototyping to production deployment.

2. Interface内で実行する関数(predict)を定義

 入力値となる画像を受け取り, 予測を返す関数を定義する必要があります. 予測値は, 辞書型で Keyをクラス名, Valueを信頼度で返すような形となっています.
今回の例ではクラス名についてはImageNetで使われているものをテキストファイルから読み込んでいます.

入力・出力
  • 入力 inp : 画像 (PIL Image)
  • 出力 confidences : 辞書型 (Key: クラス名, Value: 信頼度)
import requests
from PIL import Image
from torchvision import transforms

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def predict(inp):
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}    
  return confidences

3. Interface を定義

 Interface関数の引数にはアプリの実装に必要な情報をほぼ設定するような形になります.その中でも必ず必要なものは 実装する関数名, 入力値, 出力値です.
その他は必要に応じて加えるオプションになりますが,ここでは一つずつ説明していきます!

import gradio as gr
demo = gr.Interface(実装する関数, 入力値, 出力値, (例, タイトル, 説明))

(1) 実装する関数

 これは,今回の場合は上で定義した 「predict」となります

(2) 入力値

 入力画面の設定です.入力値としては画像を受け取るので,「gr.inputs.Image」とし, 今回は Pillowの形式で画像を処理するので type="pil"と設定します

inputs=gr.inputs.Image(type="pil")

(3) 出力値

 今回の場合は,入力画像に写っているものを予測し,出力することを目的としています.そこでgr.outputs.Labelを用いて,引数でnum_top_classes=3と設定することで 上位3つとなる予測候補を出力できるようになります!(この値を変えることで出力する候補数を変えられますよ!)

outputs=gr.outputs.Label(num_top_classes=3)

4. launch してアプリを構築

Interfaceを定義した後は,この1文の実行を行うだけでアプリをサーバーレスで使えるようになりますよ!

demo.launch()

サンプルコード

今回実装したコードを全部合わせてると以下のようになります!そのままコードを実行するだけで この計算機アプリ作れますよ!(gradio のライブラリインストールはお忘れずに!)

import gradio as gr
import torch
import requests
from torchvision import transforms

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def predict(inp):
  inp = transforms.ToTensor()(inp).unsqueeze(0)
  with torch.no_grad():
    prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    confidences = {labels[i]: float(prediction[i]) for i in range(1000)}    
  return confidences

demo = gr.Interface(fn=predict, 
             inputs=gr.inputs.Image(type="pil"),
             outputs=gr.outputs.Label(num_top_classes=3),
             )
             
demo.launch()

コメント

タイトルとURLをコピーしました