python3に慣れるための線形分類器実装に関するメモ
目的
python3に慣れるために単細胞パーセプトロンを実装します.
数式の解説等は無し.
環境
- PyCharm Community Edition 2018.1
- Python 3.6.5(venv)
- numpy 1.14.2
- matplotlib 2.2.2
とりあえず実装
import
tkを導入していない場合(?)はGUIを使わない場合でもエラーが出るので,matplotlib.use("Agg")
する.
from collections import namedtuple import random import numpy import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as pyplot
シグモイド関数
numpyに用意されていると思っていたら無いらしいので作る.
def sigmoid(x, alpha=1): return 1.0 / (1.0 + numpy.exp(-alpha*x))
学習用データ生成
正解の分類関数classify_answer
を作る.線形分類器を作りたいので,線形分類可能な想定解にする.
N
個のサンプルデータを作る.各サンプルデータはD=2
次元で,0.0..1.0
の範囲の値を取る.
整数の乱数はrandom.randint
なのに,実数の乱数はrandom.randfloat
ではなくrandom.uniform
.
inputData
の生成がappend
実装で回りくどい.もっと良い書き方がありそう.
Data = namedtuple("Data", "input answer") N = 100 D = 2 inputData = [] for i in range(N): p = [random.uniform(0.0, 1.0), random.uniform(0.0, 1.0)] inputData.append(Data(input=p, answer=classify_answer(p)))
初期パラメータ
numpy.array
を使って定義する.
weight = numpy.array([0.0, 1.0, 1.0])
学習
学習データを1つピックアップして,1つずつ学習させていく.ちなみに以前の記事はデータセットを丸ごと与えていた.
定数項([1] + training.input
の[1]
)を付けるのを忘れてしまいがち.
for loopCount in range(50): for training in inputData: trainingData = numpy.array([1] + training.input) trainingAnswer = training.answer predicted = sigmoid(numpy.dot(trainingData, weight)) a = -2*(trainingAnswer - predicted)*(predicted*(1-predicted)) weight -= a*trainingData
学習結果の確認
後置forよりメソッドチェーンの方が個人的には好きかな…
predicted = [1 if sigmoid(numpy.dot([1] + data.input, weight)) > 0.5 else 0 for data in inputData] correct = sum(1 if predicted[i] == inputData[i].answer else 0 for i in range(N)) print("correct:", correct) print("weight:", weight)
画像出力
plotTX = list(map(lambda data: data.input[0], filter(lambda data: sigmoid(numpy.dot([1] + data.input, weight)) >= 0.5, inputData))) plotTY = list(map(lambda data: data.input[1], filter(lambda data: sigmoid(numpy.dot([1] + data.input, weight)) >= 0.5, inputData))) plotFX = list(map(lambda data: data.input[0], filter(lambda data: sigmoid(numpy.dot([1] + data.input, weight)) < 0.5, inputData))) plotFY = list(map(lambda data: data.input[1], filter(lambda data: sigmoid(numpy.dot([1] + data.input, weight)) < 0.5, inputData))) print(plotTX) pyplot.plot(plotTX, plotTY, "o") pyplot.plot(plotFX, plotFY, "o") pyplot.savefig("result.png")
実行結果の確認
correct: 99 weight: [-11.781737883074062, 11.49195397228784, 11.9669277377479]
この様な画像が出力される
リファクタリング
クラスで書き直した.
from collections import namedtuple import random import numpy import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as pyplot def sigmoid(x, alpha=1): return 1.0 / (1.0 + numpy.exp(-alpha*x)) def classify_answer(param): return 1 if param[0] + param[1] > 1.0 else 0 class Perceptron: def __init__(self, d): self.D = D self.weight = numpy.array([1.0]*d) def predict(self, data): return sigmoid(numpy.dot(numpy.array([1]+data), self.weight)) def train(self, data, answer): x = numpy.array([1]+data) p = sigmoid(numpy.dot(x, self.weight)) a = -2 * (answer - p) * (p * (1 - p)) self.weight -= a * x Data = namedtuple("Data", "input answer") N = 100 D = 3 # 定数項も含めた次元の数 inputData = [] for i in range(N): p = [random.uniform(0.0, 1.0), random.uniform(0.0, 1.0)] inputData.append(Data(input=p, answer=classify_answer(p))) pe = Perceptron(D) for loopCount in range(50): for training in inputData: pe.train(training.input, training.answer) predicted = [1 if pe.predict(data.input) > 0.5 else 0 for data in inputData] correct = sum(1 if predicted[i] == inputData[i].answer else 0 for i in range(N)) print("correct:", correct) print("weight:", pe.weight) plotTX = list(map(lambda data: data.input[0], filter(lambda data: pe.predict(data.input) >= 0.5, inputData))) plotTY = list(map(lambda data: data.input[1], filter(lambda data: pe.predict(data.input) >= 0.5, inputData))) plotFX = list(map(lambda data: data.input[0], filter(lambda data: pe.predict(data.input) < 0.5, inputData))) plotFY = list(map(lambda data: data.input[1], filter(lambda data: pe.predict(data.input) < 0.5, inputData))) print(plotTX) pyplot.plot(plotTX, plotTY, "o") pyplot.plot(plotFX, plotFY, "o") pyplot.savefig("result.png")
参考サイト・資料
- Pyplot tutorial — Matplotlib 2.2.2 documentation
- とりあえず描く — matplotlib 1.0 documentation
- 電子情報通信工学シリーズ 学習とニューラルネットワーク 森北出版
よく見たら横浜駅の人だった.