强化学习实战 | 表格型Q-Learning玩井字棋(一)
时间:2021-12-08 作者:wsy950409
在 强化学习实战 | 自定义Gym环境之井子棋 中,我们构建了一个井字棋环境,并进行了测试。接下来我们可以使用各种强化学习方法训练agent出棋,其中比较简单的是Q学习,Q即Q(S, a),是状态动作价值,表示在状态s下执行动作a的未来收益的总和。Q学习的算法如下:
可以看到,当agent在状态S,执行了动作a之后,得到了环境给予的奖励R,并进入状态S\'。同时,选择最大的Q(S\', a),更新Q(S, a)。所谓表格型Q学习,就是构建一个Q(S, a)的表格,维护所有的状态动作价值。 一个很好的示例来自 Q学习玩Flappy Bird,随着游戏的不断进行,Q表格中记录的状态越来越多,状态动作价值也越来越准确,于是小鸟也飞得越来越好。
我们也要构建这样的Q表格,并希望通过Q_table[state][action] 的检索方式访问其储存的状态动作价值,我们可以用字典实现:
\'[1, 0, -1, 0, 0, 0, 1, -1, 0]\' | {\'(0,1)\':0, \'(1,0)\':0, \'(1,1)\':0, \'(1,2)\':0, \'(2,2)\':0} |
\'[0, 1, 0, -1, 0, 0, -1, 0, 1]\' | ...... |
在本文中我们要做到如下的目标:
- 改写 强化学习实战 | 自定义Gym环境之井子棋 中的测试代码,要更有逻辑,更能凸显强化学习中 agent 和环境的概念。
- agent 随机选择空格进行动作,每次动作前,更新Q表格:若表格中不存在当前状态,则将当前状态及其动作价值添加至Q表格中。
- 玩50000次游戏,查看Q表格中的状态数
步骤1:创建文件
在任意目录新建文件 Table QLearning play 域名
步骤2:创建类 Agent()
Agent() 类 需要有(1)随机落子的动作生成函数(2)Q表格(3)更新Q表格的函数,且新增表格中全部状态动作价值设为0。代码如下:
class Agent(): def __init__(self): 域名ble = {} def getEmptyPos(self, env_): # 返回空位的坐标 action_space = [] for i, row in enumerate(域名e): for j, one in enumerate(row): if one == 0: 域名nd((i,j)) return action_space def randomAction(self, env_, mark): # 随机选择空格动作 actions = 域名mptyPos(env_) action_pos = 域名ce(actions) action = {\'mark\':mark, \'pos\':action_pos} return action def updateQtable(self, env_): # 更新Q表格 state = 域名e if str(state) not in 域名ble: # 新增状态 域名ble[str(state)] = {} actions = 域名mptyPos(env_) for action in actions: 域名ble[str(state)][str(action)] = 0 # 新增的状态动作价值为0
步骤3:创建类 Game()
Game() 类需要有(1)是/否显示游戏过程、更改行动时间间隔的属性(2)开局随机先后手(3)切换行动方的函数(4)游戏结束时,可以新建游戏。代码如下:
class Game(): def __init__(self, env): 域名RVAL = 0 # 行动间隔 域名ER = False # 是否显示游戏过程 域名t = \'blue\' if 域名om() > 0.5 else \'red\' # 随机先后手 域名entMove = 域名t # 当前行动方 域名 = env 域名t = Agent() def switchMove(self): # 切换行动玩家 move = 域名entMove if move == \'blue\': 域名entMove = \'red\' elif move == \'red\': 域名entMove = \'blue\' def newGame(self): # 新建游戏 域名t = \'blue\' if 域名om() > 0.5 else \'red\' 域名entMove = 域名t 域名t() def run(self): # 玩一局游戏 域名t() # 在第一次step前要先重置环境 不然会报错 while True: if 域名entMove == \'blue\': 域名teQtable(域名) # 只记录蓝方视角下的局面 action = 域名omAction(域名, 域名entMove) state, reward, done, info = 域名(action) if 域名ER: 域名er() 域名chMove() 域名p(域名RVAL) if done: 域名ame() if 域名ER: 域名er() 域名p(域名RVAL) break
步骤4:测试
(1)玩一局游戏,显示Q表格,及Q表格中储存的状态数:
env = 域名(\'TicTacToeEnv-v0\') game = Game(env) for i in range(1): 域名() for state in 域名ble: print(state) for action in 域名ble[state]: print(action, \': \', 域名ble[state][action]) print(\'--------------\') print(\'dim of state: \', len(域名ble))
输出:
(2) 玩50000局游戏,查看Q表格中储存的状态数:
env = 域名(\'TicTacToeEnv-v0\') game = Game(env) for i in range(50000): 域名() print(\'dim of state: \', len(域名ble))
输出:
整体代码如下:
import gym import random import time # 查看所有已注册的环境 # from gym import envs # print(域名()) class Game(): def __init__(self, env): 域名RVAL = 0 # 行动间隔 域名ER = False # 是否显示游戏过程 域名t = \'blue\' if 域名om() > 0.5 else \'red\' # 随机先后手 域名entMove = 域名t 域名 = env 域名t = Agent() def switchMove(self): # 切换行动玩家 move = 域名entMove if move == \'blue\': 域名entMove = \'red\' elif move == \'red\': 域名entMove = \'blue\' def newGame(self): # 新建游戏 域名t = \'blue\' if 域名om() > 0.5 else \'red\' 域名entMove = 域名t 域名t() def run(self): # 玩一局游戏 域名t() # 在第一次step前要先重置环境 不然会报错 while True: if 域名entMove == \'blue\': 域名teQtable(域名) # 只记录蓝方视角下的局面 action = 域名omAction(域名, 域名entMove) state, reward, done, info = 域名(action) if 域名ER: 域名er() 域名chMove() 域名p(域名RVAL) if done: 域名ame() if 域名ER: 域名er() 域名p(域名RVAL) break class Agent(): def __init__(self): 域名ble = {} def getEmptyPos(self, env_): # 返回空位的坐标 action_space = [] for i, row in enumerate(域名e): for j, one in enumerate(row): if one == 0: 域名nd((i,j)) return action_space def randomAction(self, env_, mark): # 随机选择空格动作 actions = 域名mptyPos(env_) action_pos = 域名ce(actions) action = {\'mark\':mark, \'pos\':action_pos} return action def updateQtable(self, env_): state = 域名e if str(state) not in 域名ble: 域名ble[str(state)] = {} actions = 域名mptyPos(env_) for action in actions: 域名ble[str(state)][str(action)] = 0 env = 域名(\'TicTacToeEnv-v0\') game = Game(env) for i in range(1): 域名() for state in 域名ble: print(state) for action in 域名ble[state]: print(action, \': \', 域名ble[state][action]) print(\'--------------\') print(\'dim of state: \', len(域名ble))View Code