Source code for qualia2.rl.agents.ddqn

# -*- coding: utf-8 -*- 
from ..rl_core import ValueAgent, np
from ..rl_util import Trainer

class DDQN(ValueAgent):
    ''' Double DQN (2015) implementation\n
    This implementation uses double networks for learning. The DDQN class
    incorporates the model (Module) and the optim (Optimizer). The model learns
    with experience replay, which is implemented in the update() method.
    '''
    def __init__(self, eps, actions):
        super().__init__(eps, actions)
    def get_train_signal(self, experience, gamma):
        self.model.eval()
        self.target.eval()
        state, next_state, reward, action, done = experience
        # Q(s, a) from the online network for the actions actually taken
        state_action_value = self.model(state).gather(1, action)
        # Double DQN: the online network selects the greedy next action,
        # while the target network evaluates it
        action_next = np.argmax(self.model(next_state).data, axis=1).reshape(-1, 1)
        next_state_action_value = self.target(next_state).gather(1, action_next)
        # terminal transitions contribute no bootstrap value
        next_state_action_value[done] = 0
        target_action_value = reward + gamma * next_state_action_value
        return state_action_value, target_action_value.detach()
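# The heart of get_train_signal() is the Double DQN target: the online network
# picks the greedy next action and the target network evaluates it. Below is a
# minimal sketch of that computation with toy numbers, assuming plain NumPy
# semantics for np; every array and the helper name are hypothetical, not part
# of qualia2.
def _double_q_target_sketch():
    gamma = 0.99
    # toy batch of 3 transitions with 2 actions (made-up values)
    q_online_next = np.array([[1.0, 2.0], [0.5, 0.1], [3.0, 2.5]])  # model(next_state)
    q_target_next = np.array([[0.9, 1.8], [0.4, 0.2], [2.7, 2.6]])  # target(next_state)
    reward = np.array([[1.0], [0.0], [1.0]])
    done = np.array([False, True, False])
    # selection by the online network, evaluation by the target network
    action_next = np.argmax(q_online_next, axis=1).reshape(-1, 1)
    next_value = np.take_along_axis(q_target_next, action_next, axis=1)
    next_value[done] = 0  # terminal transitions have no bootstrap value
    return reward + gamma * next_value  # [[2.782], [0.0], [3.673]]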
class DDQNTrainer(Trainer):
    '''DDQNTrainer\n
    Args:
        memory (deque): replay memory object
        batch (int): batch size for training
        capacity (int): capacity of the memory
        gamma (float): discount factor
        target_update_interval (int): episode interval for updating the target network
    '''
    def __init__(self, memory, batch, capacity, gamma=0.99, target_update_interval=3):
        super().__init__(memory, batch, capacity, gamma)
        self.target_update_interval = target_update_interval
    def after_episode(self, episode, steps, agent, loss, reward, filename=None):
        super().after_episode(episode, steps, agent, loss, reward, filename)
        # periodically sync the target network with the online network
        if episode % self.target_update_interval == 0:
            agent.update_target_model()
    def train(self, env, agent, episodes=200, render=True, filename=None):
        self.before_train(env, agent)
        self.train_routine(env, agent, episodes=episodes, render=render, filename=filename)
        self.after_train()
        return agent
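# A minimal usage sketch, guarded so it does not run on import. Assumptions not
# shown in this module: a gym-style environment ('CartPole-v0' is an arbitrary
# choice), the eps value, and the exact replay-memory wiring (the docstring
# calls memory a deque object; how capacity interacts with it is assumed here).
# The agent's model and optimizer are expected to be attached elsewhere in
# qualia2 before training.
if __name__ == '__main__':
    from collections import deque
    import gym  # assumed environment library

    env = gym.make('CartPole-v0')
    agent = DDQN(eps=1.0, actions=env.action_space.n)  # eps: initial exploration rate (assumed)
    trainer = DDQNTrainer(memory=deque(maxlen=10000), batch=32, capacity=10000,
                          gamma=0.99, target_update_interval=3)
    agent = trainer.train(env, agent, episodes=200, render=False)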