游戏程序 SnakeGame 属性 1 2 3 4 5 6 7 8 9 10 self.board_size = board_size self.grid_size = self.board_size ** 2 self.cell_size = 40 self.width = self.height = self.board_size * self.cell_size self.border_size = 20 self.display_width = self.width + 2 * self.border_size self.display_height = self.height + 2 * self.border_size + 40 self.silent_mode = silent_mode
1 2 3 4 5 6 7 self.snake = None self.non_snake = None self.direction = None self.score = 0 self.food = None self.seed_value = seed
Snake 的数据结构 是一个数组
1 2 3 self.snake = [(self.board_size // 2 + i, self.board_size // 2 ) for i in range (1 , -2 , -1 )] self.non_snake = set ([(row, col) for row in range (self.board_size) for col in range (self.board_size) if (row, col) not in self.snake])
(精)移动逻辑 1 2 3 4 5 6 7 8 9 10 if (row, col) == self.food: food_obtained = True self.score += 10 if not self.silent_mode: self.sound_eat.play()else : food_obtained = False self.non_snake.add(self.snake.pop())
数据与渲染分离 render 每 0.15s 调用一次
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 def render (self ): self.screen.fill((0 , 0 , 0 )) pygame.draw.rect(self.screen, (255 , 255 , 255 ), (self.border_size - 2 , self.border_size - 2 , self.width + 4 , self.height + 4 ), 2 ) self.draw_snake() if len (self.snake) < self.grid_size: r, c = self.food pygame.draw.rect(self.screen, (255 , 0 , 0 ), (c * self.cell_size + self.border_size, r * self.cell_size + self.border_size, self.cell_size, self.cell_size)) self.draw_score() pygame.display.flip() for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit()
绘制面板信息(由事件触发) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 def draw_score (self ): score_text = self.font.render(f"Score: {self.score} " , True , (255 , 255 , 255 )) self.screen.blit(score_text, (self.border_size, self.height + 2 * self.border_size))def draw_welcome_screen (self ): title_text = self.font.render("SNAKE GAME" , True , (255 , 255 , 255 )) start_button_text = "START" self.screen.fill((0 , 0 , 0 )) self.screen.blit(title_text, (self.display_width // 2 - title_text.get_width() // 2 , self.display_height // 4 )) self.draw_button_text(start_button_text, (self.display_width // 2 , self.display_height // 2 )) pygame.display.flip()def draw_game_over_screen (self ): game_over_text = self.font.render("GAME OVER" , True , (255 , 255 , 255 )) final_score_text = self.font.render(f"SCORE: {self.score} " , True , (255 , 255 , 255 )) retry_button_text = "RETRY" self.screen.fill((0 , 0 , 0 )) self.screen.blit(game_over_text, (self.display_width // 2 - game_over_text.get_width() // 2 , self.display_height // 4 )) self.screen.blit(final_score_text, (self.display_width // 2 - final_score_text.get_width() // 2 , self.display_height // 4 + final_score_text.get_height() + 10 )) self.draw_button_text(retry_button_text, (self.display_width // 2 , self.display_height // 2 )) pygame.display.flip()def draw_button_text (self, button_text_str, pos, hover_color=(255 , 255 , 255 ), normal_color=(100 , 100 , 100 ) ): mouse_pos = pygame.mouse.get_pos() button_text = self.font.render(button_text_str, True , normal_color) text_rect = button_text.get_rect(center=pos) if text_rect.collidepoint(mouse_pos): colored_text = self.font.render(button_text_str, True , hover_color) else : colored_text = self.font.render(button_text_str, True , normal_color) self.screen.blit(colored_text, text_rect)def draw_countdown (self, number ): countdown_text = self.font.render(str (number), True , (255 , 255 , 255 )) self.screen.blit(countdown_text, (self.display_width // 2 - countdown_text.get_width() // 2 , self.display_height // 2 - countdown_text.get_height() // 2 )) pygame.display.flip()
概念 pygame Pygame 是一个基于 Python 语言的游戏开发库,它提供了一系列用于游戏开发的工具和函数。使用 Pygame 可以轻松地创建 2D 游戏,包括游戏窗口、图像、声音、键盘鼠标输入等等。Pygame 是一个跨平台的游戏开发库,可以在 Windows、Linux、Mac 等多种操作系统上运行。
下面是一个简单的 Pygame 程序示例,用于创建一个窗口并显示一张图片:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 import pygame pygame.init() screen_width, screen_height = 640 , 480 screen = pygame.display.set_mode((screen_width, screen_height)) image = pygame.image.load("image.png" ) screen.blit(image, (0 , 0 )) pygame.display.flip()while True : for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() exit()
在这个示例中,我们首先使用pygame.init()初始化 Pygame,然后使用pygame.display.set_mode()函数创建一个指定大小的窗口。接着,我们使用pygame.image.load()函数加载一张图片,并使用screen.blit()函数将图片显示在窗口上。最后,我们使用pygame.display.flip()函数刷新屏幕,将窗口上的内容显示出来。在程序的最后,我们使用一个无限循环等待关闭窗口事件,当用户点击窗口的关闭按钮时,程序退出并关闭 Pygame。
这只是一个简单的示例,Pygame 还提供了很多其他的功能和工具,可以用于创建更加复杂的游戏。
训练模型:train_cnn.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 env = SubprocVecEnv([make_env(seed=s) for s in seed_set]) lr_schedule = linear_schedule(2.5e-4 , 2.5e-6 ) clip_range_schedule = linear_schedule(0.150 , 0.025 ) model = MaskablePPO( "CnnPolicy" , env, device="cuda" , verbose=1 , n_steps=2048 , batch_size=512 , n_epochs=4 , gamma=0.94 , learning_rate=lr_schedule, clip_range=clip_range_schedule, tensorboard_log=LOG_DIR )
概念 Proximal Policy Optimization (PPO) Proximal Policy Optimization (PPO,近端策略优化) 是一种流行的强化学习算法,用于训练策略神经网络来解决连续动作空间的强化学习问题。它是由 OpenAI 提出的一种基于策略梯度方法的改进算法,旨在提高样本利用率和收敛性能。
PPO 的设计目标是在保持样本有效利用的同时,通过限制策略更新的幅度来稳定训练过程。它通过两个核心概念来实现这一目标:Clipped Surrogate Objective 和 Trust Region Policy Optimization。
PPO 的工作原理如下:
采样数据:通过与环境进行交互,收集一批经验样本,包括状态、动作、奖励等。
计算优势估计:使用价值函数估计优势函数,衡量每个动作的相对优势。
更新策略网络:使用经验样本来更新策略网络,通过最大化一个被修剪的代理目标函数(Clipped Surrogate Objective)来调整策略,限制单次更新的幅度。
重复以上步骤:迭代进行多次采样、计算优势估计和更新策略网络的过程,直到达到预定的训练步数或收敛条件。
PPO 的关键优点是相对于其他策略梯度方法,它更稳定且对超参数不太敏感。它使用了一种通过限制策略更新的幅度来平衡策略改进和探索的方法,从而在保持样本有效利用的同时提供了更稳定的训练过程。
在实际应用中,PPO 被广泛用于各种连续动作空间的强化学习任务,如机器人控制、游戏玩法等。它具有易于实现、高效收敛和良好的性能等特点,使其成为一种常用的强化学习算法。
Baseline 在机器学习中,术语”baselines”通常指的是基准模型 或基准算法 。基准模型是指在解决特定问题或任务时被广泛接受和使用的基本模型或算法 ,用作其他新算法或模型性能比较的参考点 。
基准模型在模型训练和性能评估中具有重要作用。它们可以提供一个参考标准,用来衡量新算法或模型的改进效果 。通过与基准模型进行比较,可以评估新算法的相对性能、效率或准确性。
在强化学习中,Stable Baselines 是一个基准算法库,提供了一系列经典和最先进的强化学习算法的实现。这些算法被视为基准模型,用作强化学习领域的参考点。使用 Stable Baselines 库,研究人员和开发者可以将自己的算法与这些基准算法进行比较,评估其在不同环境和任务上的性能优劣。
基准模型的选择通常基于多年的研究和实践经验,并在特定领域或任务上被广泛认可和验证。它们通常是经过精心设计和调整的算法,可以提供一定程度上的性能保证,并且能够帮助其他研究者和从业者更好地理解和解决问题。
需要注意的是,基准模型并不一定是最优解或最先进的方法,但它们代表了当前领域中被广泛接受和使用的模型或算法。它们作为参考点,可以帮助研究人员和从业者评估和推动领域的进展。
Stable Baselines3 Stable Baselines3 是一个开源的强化学习库 ,用于实现和训练强化学习算法。它是 Stable Baselines 库的第三个版本,专门针对 PyTorch 用户进行了优化和更新。
Stable Baselines3 提供了一系列经典和最先进的强化学习算法的实现,包括:
基于值函数的算法:
DQN (Deep Q-Network)
A2C (Advantage Actor-Critic)
PPO (Proximal Policy Optimization)
SAC (Soft Actor-Critic)
TD3 (Twin Delayed DDPG)
基于策略梯度的算法:
TRPO (Trust Region Policy Optimization)
PPO (Proximal Policy Optimization)
SAC (Soft Actor-Critic)
Stable Baselines3 的特点包括:
简单易用:具有简洁一致的 API 接口,方便用户定义和训练强化学习模型。
可扩展性:支持多种环境,包括 Gym、PyBullet 等,并提供了丰富的功能扩展和环境封装器。
高性能:使用高效的算法和并行化方法,以加快训练速度并提高样本利用率。
支持多种平台:可在 CPU 和 GPU 上运行,并支持分布式训练。
Stable Baselines3 建立在强化学习理论和实践的基础上,为用户提供了一个方便、高效和可靠的工具来研究和解决强化学习问题。它被广泛应用于各种领域,如机器人控制、游戏玩法、自动驾驶等。
sb3_contrib sb3_contrib 是 Stable Baselines3(SB3)库的一个扩展模块,提供了一些额外的功能和算法实现。SB3 是一个用于实现和训练强化学习算法的 Python 库 ,它提供了一系列经典和最先进的强化学习算法,以及用于构建、训练和评估强化学习智能体的工具。
sb3_contrib模块扩展了 SB3 库,为用户提供了一些额外的功能和算法,这些功能和算法可能还没有正式纳入 SB3 的核心库中,或者是由社区贡献的实验性实现。
一些在sb3_contrib模块中可能包含的功能和算法包括:
强化学习算法的变种或改进实现
额外的环境封装器或功能增强器
与其他库或框架的集成扩展
实验性的新功能和实现
需要注意的是,sb3_contrib中的功能和算法可能不受官方支持,可能缺乏文档或测试,且可能在将来的版本中发生变化。因此,在使用sb3_contrib模块中的功能时,需要谨慎评估其稳定性、适用性和可靠性。
总之,sb3_contrib是 SB3 库的一个扩展模块,提供了一些额外的功能和算法,可供用户尝试和使用,但需要谨慎评估其适用性和稳定性。
sb3_contrib.MaskablePPO sb3_contrib.MaskablePPO 是基于 Stable Baselines3 (SB3) 库的一个扩展实现,它提供了一种在 Proximal Policy Optimization (PPO) 算法中使用动作掩码 (action mask) 的能力。
sb3_contrib.MaskablePPO 扩展了 SB3 库中的 PPO 算法,使其支持动作掩码的功能。它添加了一些额外的方法和功能,允许用户在训练和使用 PPO 算法时使用动作掩码。具体来说,sb3_contrib.MaskablePPO 在以下方面进行了扩展:
增加了 get_action_mask() 方法:用于获取当前状态下的动作掩码。
增加了 get_attribute_mask() 方法:用于获取当前状态下的其他属性掩码。
修改了 learn() 方法:使其能够在训练过程中使用动作掩码来限制可执行的动作。
修改了策略网络 (policy network):使其能够接受动作掩码作为输入,并在计算动作分布时考虑动作掩码。
通过使用 sb3_contrib.MaskablePPO,您可以为特定任务或环境定义动作掩码,从而在 PPO 算法中限制智能体的动作选择范围。这对于处理具有约束的动作空间或需要动态限制动作的任务非常有用。
参数说明 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 Proximal Policy Optimization algorithm (PPO) (clip version) with Invalid Action Masking. Based on the original Stable Baselines 3 implementation. Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html Background on Invalid Action Masking: https://arxiv.org/abs /2006.14171 :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str ) :param learning_rate: The learning rate, it can be a function of the current progress remaining (from 1 to 0 ) :param n_steps: The number of steps to run for each environment per update (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) :param batch_size: Minibatch size :param n_epochs: Number of epoch when optimizing the surrogate loss :param gamma: Discount factor :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param clip_range: Clipping parameter, it can be a function of the current progress remaining (from 1 to 0 ). :param clip_range_vf: Clipping parameter for the value function, it can be a function of the current progress remaining (from 1 to 0 ). This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. :param normalize_advantage: Whether to normalize or not the advantage :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping :param target_kl: Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue By default, there is no limit on the kl div. :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None , no logging) :param policy_kwargs: additional arguments to be passed to the policy on creation :param verbose: the verbosity level: 0 no output, 1 info, 2 debug :param seed: Seed for the pseudo random generators :param device: Device (cpu, cuda, ...) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: Whether or not to build the network at the creation of the instance
动作掩码 (action mask) PPO 是一种流行的强化学习算法,用于训练策略神经网络来解决连续动作空间的强化学习问题。然而,在某些情况下,特定的环境或任务可能会限制可执行的动作集合。动作掩码 (action mask) 可以用来表示在给定状态下哪些动作是有效的,从而限制智能体的选择范围 。
动作掩码(action mask)是在强化学习中用于限制智能体在特定状态下可执行的动作集合的一种技术。它是一个与动作空间对应的二进制数组或向量,其中每个元素表示对应动作的可执行性 。
在某些强化学习任务中,特定的环境或任务可能会限制智能体在某些状态下执行特定的动作。例如,对于一个机器人导航任务,某些位置可能无法执行特定的移动动作,或者某些对象可能只能执行特定的操作动作。这时,可以使用动作掩码来限制智能体在这些状态下选择的动作 ,以确保只选择有效的动作 。
动作掩码通常是一个与动作空间相同维度的二进制数组或向量。对于每个动作,掩码中的对应位置为 1 表示该动作是可执行的,为 0 表示该动作是不可执行的 。智能体在选择动作时,会参考动作掩码来限制其选择范围,只从可执行的动作中进行选择。
在训练过程中,动作掩码可以动态地根据环境的状态进行更新 ,以适应可能发生的限制变化。通过使用动作掩码,可以增强智能体的决策能力,提高任务的效率和性能。
需要注意的是,动作掩码只是对动作空间的限制,并不改变智能体的策略网络或价值函数的结构。它只是在动作选择阶段进行过滤,并不影响强化学习算法的其他方面。
以这里的 SnakeEnv 为例:
1 2 def get_action_mask (self ): return np.array([[self._check_action_validity(a) for a in range (self.action_space.n)]])
Linear Scheduler 一个线性调度器(linear scheduler)通常用于在一定时间范围内线性地调整某个参数的值 。它可以用于调整学习率、探索率或其他需要动态调整的参数 。
线性调度器通过在每个时间步骤或每个训练批次中计算参数的新值,以实现逐渐线性变化的效果。通常,线性调度器有一个起始值和一个目标值,以及一个指定时间或步骤的总数。根据当前的时间或步骤,线性调度器计算出一个介于起始值和目标值之间的新值。
下面是一个示例,演示如何使用线性调度器来调整学习率:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class LinearScheduler : def __init__ (self, start_value, end_value, total_steps ): self.start_value = start_value self.end_value = end_value self.total_steps = total_steps self.current_step = 0 def get_value (self ): if self.current_step >= self.total_steps: return self.end_value fraction = self.current_step / self.total_steps value = self.start_value + fraction * (self.end_value - self.start_value) return value def step (self ): self.current_step += 1
在上述代码中,LinearScheduler类接受起始值(start_value)、目标值(end_value)和总步骤数(total_steps)作为参数进行初始化。
get_value()方法根据当前步骤数计算并返回参数的新值。如果当前步骤数超过了总步骤数,那么返回目标值(线性变化已经完成)。否则,根据当前步骤数与总步骤数的比例,计算参数的新值。
step()方法用于将当前步骤数加 1,以在每个时间步骤中前进。
通过创建线性调度器对象并在每个训练步骤中调用get_value()方法来获取参数的新值,可以逐渐线性调整参数,从起始值变化到目标值。
(TODO)SnakeEnv 继承自 gym.Env
gym.Env gym.Env是 OpenAI Gym 库中定义的一个基类,用于创建强化学习环境 。它提供了一组标准的方法和属性,用于定义一个环境的行为和状态空间,以及与环境进行交互的方法。
gym.Env的子类必须实现以下几个方法:
reset(): 重置环境的状态并返回初始观测值。
step(action): 执行一个动作,并返回下一个观测值、奖励、是否终止以及其他相关的信息。
render(): 可选方法,用于可视化环境的当前状态。
close(): 可选方法,用于清理环境的资源。
此外,gym.Env还定义了一些属性,例如:
action_space: 表示动作空间的对象,描述了可用动作的范围和结构。
observation_space: 表示观测空间的对象,描述了观测值的范围和结构。
reward_range: 表示奖励范围的元组,定义了环境中可能的奖励值的最小值和最大值。
通过继承gym.Env类并实现所需的方法和属性,可以创建自定义的强化学习环境,并与其他 Gym 库中的算法和工具进行集成。
使用模型:test_cnn.py 可以看到,模型是用于预测下一步的 action:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 action, _ = model.predict(obs, action_masks=env.get_action_mask()) prev_mask = env.get_action_mask() prev_direction = env.game.direction num_step += 1 obs, reward, done, info = env.step(action)if done: if info["snake_size" ] == env.game.grid_size: print (f"You are BREATHTAKING! Victory reward: {reward:.4 f} ." ) else : last_action = ["UP" , "LEFT" , "RIGHT" , "DOWN" ][action] print (f"Gameover Penalty: {reward:.4 f} . Last action: {last_action} " )elif info["food_obtained" ]: print (f"Food obtained at step {num_step:04d} . Food Reward: {reward:.4 f} . Step Reward: {sum_step_reward:.4 f} " ) sum_step_reward = 0 else : sum_step_reward += reward
模型调用是这一行:
1 action, _ = model.predict(obs, action_masks=env.get_action_mask())
这段代码使用了一个名为model的模型来预测一个动作。下面是对代码的解释:
obs是表示当前观测值的变量。这个变量包含了环境的状态信息,可以是一个向量、图像或其他形式的数据。
env.get_action_mask()是调用了环境对象的get_action_mask()方法,用于获取动作掩码(action masks)。动作掩码是一个布尔数组,用于表示在当前状态下哪些动作是有效的。可能是因为某些动作在当前状态下是不可执行的,或者环境有一些限制条件。
model.predict(obs, action_masks=env.get_action_mask())是调用了model模型的predict方法,传入了obs作为输入和action_masks作为可选参数。predict方法用于对给定的输入进行预测,并返回一个预测的结果。
结果通过元组的形式进行赋值,action, _ = model.predict(obs, action_masks=env.get_action_mask())中的_表示不使用的临时变量,通常用于忽略不需要的返回值。
总的来说,这段代码的目的是使用model模型对当前观测值进行预测,并将预测的动作存储在action变量中。预测过程可能受到动作掩码的限制,以确保只选择有效的动作。
资料 源码:
GitHub - linyiLYi/snake-ai: An AI agent that beats the classic game "Snake".
视频:
https://www.bilibili.com/video/BV1DT411H7ph/