だれでもできる強化学習の例をPythonで入門者向けに説明
強化学習を勉強しはじめたは良いものの、なんだか理論が多すぎて、
「それで結局何をどうすれば強化学習モデルを作れるの?」
「どんな風に強化学習を応用していいかわからない」
と思ってしまう強化学習入門者に向けて、
だれでもわかるように説明しました!
この記事を読むと、まず
- 簡単な強化学習モデルを構築することができるようになる
- 他の色々なことに応用できるようになる
ということができるようになります。
ちなみに僕は教師あり学習モデルに強化学習モデルを組み合わせることによって、精度を上げることに成功しました。具体的にいうと、RMSLEという評価指標を0.47から0.32まで改善することができました。
では、早速実装例を見てください。
タクシーのデータを使っています。
# ライブラリを読み込みます
import time
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.offline
from plotly.offline import init_notebook_mode, iplot_mpl
#強化学習の実装例です。Qテーブルを使います
class QLearningTable:
def __init__(self,actions, learning_rate=0.005, reward_decay=0.9, e_greedy=0.9):
self.actions = actions
self.lr = learning_rate
self.gamma = reward_decay
self.epsilon = e_greedy
table_cols=np.arange(0, 361)
table_rows=np.arange(10, 61901, 10)
self.q_table = pd.DataFrame(np.zeros([len(table_rows), len(table_cols)]), columns=table_cols, index=table_rows)
def getIndex(self, aList, aValue):
refer_sr = pd.Series(aList,index=aList)
return (abs(refer_sr - aValue)).idxmin()
def getQTable(self):
return self.q_table
def getRate(self, aSeed):
tSeed = self.getIndex(self.getQTable().index, aSeed)
return self.getQTable().loc[:, 0.05:].idxmax(axis=1)[tSeed]
def learn(self, distance, houi, diff_prev_time):
print("==============")
print("distance", distance, "houi",houi, "diff_prev_time", diff_prev_time)
# 状態
aIndex = self.getIndex(self.q_table.index, distance)
# 以下で、あるdistanceにおけるhouiのq値を更新
for houi_column in self.q_table.columns:
q_target = 0
if houi_column == houi:
q_predict = self.q_table.loc[aIndex, houi_column]
if diff_prev_time <= datetime.timedelta(minutes=5):
q_target += 1
elif diff_prev_time > datetime.timedelta(minutes=5) and diff_prev_time <= datetime.timedelta(minutes=10):
q_target += 0
else:
q_target += -1
self.q_table.loc[aIndex, houi_column] += self.lr * (q_target - q_predict)
def __init__のパラメータは、とりあえず置いておきましょう。ここで頭の中に?が浮かびすぎて、後の実装に支障をきたしてしまわぬように。
重要なのは、learnメソッドのパラメータです。
ここでデータの説明を軽くしておきます。distance, houi, diff_prev_timeの3つがありますが、これは距離、方位、一個前のレコードとの時間差です。
前の2つはある駅を中心とした距離と、方位です。
この3つのデータを使って、タクシーがどこらへんに居る時に、すぐに客を捕まえられるのかを可視化しようとしています。
ちなみにデータは、実車(客が乗っている状態)と空車(客が乗っていない状態)が交互になるように整形したものを使っています。なので、diff_prev_timeは客を降ろしてから客を乗せるまでの時間(あるいは客を乗せてから下ろすまでの時間)になります。
整形するのはこんな感じで記述できます
data['prev_status'] = data.shift(1)['status']
data['prev_status'] = data['prev_status'].fillna(method='bfill')
data['change_status_flg'] = data[data['status'] == data['prev_status']]
data = data[data['change_status_flg'] == 1]
では、Qテーブルにデータを入れて、強化学習させます。
data_tmp = data.query('status == "実車"').copy()
RL = QLearningTable(actions=[])
# 全てのレコードをループ
for index in data_tmp.index:
RL3.learn(houi=data_tmp.loc[index, 'houi'],
distance=round(data_tmp.loc[index,'distance'],3),
diff_prev_time=data_tmp.loc[index,'diff_prev_time']
)
distanceをround関数で小数第三位までにしていますが、これはデータ量を少なくすることと、ある程度感覚を空けることで、見やすくするためです。
これによって、Qテーブルが出来上がりまして、ある駅を中心にどこらへんがすぐに捕まえやすいかがわかるのです。
ただ、これでは局所的に捕まえやすいところがわからないというデメリットもあるのです。
どういうことかというと、ある区間が全然捕まらなくて、その区間の端にすごく捕まりやすいところがあるとします。そうすると、その捕まりやすいはずの部分は、おそらく空車の時間が長いはずなので、Qテーブルでは低く評価されます。なので、局所的に捕まりやすいけれど、その周りが広範囲で捕まりにくいところを、このQテーブルは正しく判断できないのです。
とはいえ、タクシーで行きたいのは局所的に人が捕まりやすいところではなく、広範囲で人が捕まりやすいところのはずなので、ここでは気にしなくても良いということにします。
こういった感じでQテーブルを作り、あとはpythonの地図ライブラリであるfoliumなどで可視化すれば、はっきりとタクシーで捕まりやすい場所がわかるでしょう。
強化学習はとっつきにくいかもしれませんが、やってみると意外と簡単にできたりするので、ぜひ一度試してみてくださいね。