# -*- coding: utf-8 -*-
from ..core import *
from ..autograd import Tensor
from ..functions import huber_loss, amax, mse_loss
import copy
import random
import numpy
import gym
import matplotlib.pyplot as plt
from matplotlib import animation
from logging import getLogger
logger = getLogger('QualiaLogger').getChild('rl')
class BaseAgent(object):
    ''' BaseAgent \n
    Base class for agents. Some methods need to be overridden.

    Args:
        actions (int or list): number of discrete actions, or a list of actions
        model (Module): model network
    '''
def __init__(self, actions, model):
self.actions = actions
self.eps = 1
self.model = model
        # keep an independent copy so the target network does not share parameters with the online model
        self.target = copy.deepcopy(model)
self.update_target_model()
self.optim = None
self.episode_count = 0
    @classmethod
def init(cls, env, model):
actions = env.action_space.n
return cls(actions, model)
    def set_optim(self, optim, **kwargs):
self.optim = optim(self.model.params, **kwargs)
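    # set_optim() binds an optimizer class to the model parameters, e.g.
    # (Adam is assumed here to be an optimizer class exposed by this package):
    #
    #     agent.set_optim(Adam, lr=1e-3)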
def __str__(self):
return str('{}'.format(self.__class__.__name__))
def __call__(self, observation, *args):
return self.policy(observation, *args)
    def policy(self, observation, *args):
raise NotImplementedError
    def save(self, filename):
self.model.save(filename)
    def load(self, filename):
self.model.load(filename)
self.target.load_state_dict(self.model.state_dict())
    def play(self, env, render=True, filename=None):
frames = []
state = env.reset()
done = False
steps = 0
episode_reward = []
while not done:
if render:
frames.append(env.render(mode='rgb_array'))
action = self.policy(state, eps=0.001)
            next_state, reward, done, _ = env.step(action)
            episode_reward.append(reward.data[0])
            state = next_state
steps += 1
logger.info('[*] Episode end - steps: {} reward: {}'.format(steps, sum(episode_reward)))
if render:
env.close()
if filename is not None:
env.animate(frames, filename)
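    # Typical use of play() once an agent has been trained (a sketch; the task id,
    # file names and the ValueAgent subclass are illustrative assumptions):
    #
    #     env = Env('CartPole-v0')
    #     agent = ValueAgent.init(env, model)
    #     agent.load('dqn_cartpole')
    #     agent.play(env, render=True, filename='cartpole')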
    def get_train_signal(self, experience, gamma=0.9):
self.model.eval()
state, next_state, reward, action, done = experience
# get state action value
state_action_value = self.model(state).gather(1, action)
next_state_action_value = amax(self.model(next_state), axis=1)
next_state_action_value[done] = 0
target_action_value = reward + gamma * next_state_action_value
return state_action_value, target_action_value.detach()
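    # get_train_signal() above builds the one-step TD pair used by update():
    # Q(s, a) from the online model, and the target y = r + gamma * max_a' Q(s', a'),
    # with y = r on terminal transitions (done masks the bootstrap term).
    # A DQN-style subclass would typically bootstrap from the target network
    # instead; a hypothetical override (not part of this class) could read:
    #
    #     def get_train_signal(self, experience, gamma=0.9):
    #         state, next_state, reward, action, done = experience
    #         state_action_value = self.model(state).gather(1, action)
    #         next_state_action_value = amax(self.target(next_state), axis=1)
    #         next_state_action_value[done] = 0
    #         return state_action_value, (reward + gamma * next_state_action_value).detach()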
    def update(self, state_action_value, target_action_value, loss_func=mse_loss):
self.model.train()
loss = loss_func(state_action_value, target_action_value)
self.optim.zero_grad()
loss.backward()
self.optim.step()
return to_cpu(loss.data) if gpu else loss.data
    def update_target_model(self):
self.target.load_state_dict(self.model.state_dict())
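    # A minimal training-step sketch combining get_train_signal(), update() and
    # update_target_model() (replay_buffer, batch_size, step and sync_interval are
    # illustrative names, not part of this module):
    #
    #     experience = replay_buffer.sample(batch_size)
    #     q, target = agent.get_train_signal(experience, gamma=0.99)
    #     loss = agent.update(q, target, loss_func=huber_loss)
    #     if step % sync_interval == 0:
    #         agent.update_target_model()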
class ValueAgent(BaseAgent):
    ''' ValueAgent \n
    Base class for value based agents. Some methods need to be overridden.
    '''
    def policy(self, observation, *args, eps=None):
# returns action as numpy array
if eps is None:
eps = max(0.5*(1/(self.episode_count+1)), 0.05)
if random.random() < eps:
return numpy.random.choice(self.actions)
else:
self.model.eval()
return numpy.argmax(self.model(observation.reshape(1,-1), *args).asnumpy())
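    # policy() above is epsilon-greedy: with probability eps a random action is drawn,
    # otherwise the greedy action argmax_a Q(s, a) is taken.  When eps is not given it
    # decays with the episode count, e.g.
    #
    #     episode_count = 0  -> eps = max(0.5 * 1/1,  0.05) = 0.5
    #     episode_count = 9  -> eps = max(0.5 * 1/10, 0.05) = 0.05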
class PolicyAgent(BaseAgent):
    ''' PolicyAgent \n
    Base class for policy based agents. Some methods need to be overridden.
    '''
    def policy(self, observation, *args, eps=None):
# returns action as numpy array
if eps is None:
eps = max(0.5*(1/(self.episode_count+1)), 0.05)
if random.random() < eps:
return numpy.random.choice(self.actions)
else:
self.model.eval()
            return numpy.random.choice(self.actions, p=self.model(observation.reshape(1,-1), *args).asnumpy().flatten())
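    # policy() above samples an action from the distribution produced by the model,
    # so the network output is expected to be a valid probability vector over the
    # actions (e.g. a softmax); the flatten() call assumes a batch output of shape
    # (1, n_actions), since numpy.random.choice requires a 1-D p argument.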
class ActorCriticAgent(BaseAgent):
    ''' ActorCriticAgent \n
    Base class for actor-critic based agents. Some methods need to be overridden.

    Args:
        actor (Module): actor network
        critic (Module): critic network
    '''
def __init__(self, actor, critic):
self.eps = 1
        self.actor = actor
        # keep independent copies for the target networks so they do not share parameters with the online networks
        self.actor_target = copy.deepcopy(actor)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic = critic
        self.critic_target = copy.deepcopy(critic)
        self.critic_target.load_state_dict(self.critic.state_dict())
self.actor_optim = None
self.critic_optim = None
self.episode_count = 0
    @classmethod
def init(cls, env, actor, critic):
return cls(actor, critic)
    def set_actor_optim(self, optim, **kwargs):
self.actor_optim = optim(self.actor.params, **kwargs)
    def set_critic_optim(self, optim, **kwargs):
self.critic_optim = optim(self.critic.params, **kwargs)
    def policy(self, observation, *args, eps=None):
return self.actor(observation).asnumpy()
    def save(self, filename):
self.actor.save(filename+'_actor')
self.critic.save(filename+'_critic')
    def load(self, filename):
self.actor.load(filename+'_actor')
self.actor_target.load_state_dict(self.actor.state_dict())
self.critic.load(filename+'_critic')
self.critic_target.load_state_dict(self.critic.state_dict())
    def load_actor(self, filename):
self.actor.load(filename+'_actor')
self.actor_target.load_state_dict(self.actor.state_dict())
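    # A wiring sketch for an actor-critic agent (Adam is assumed to be an optimizer
    # class exposed by this package; the names are illustrative only):
    #
    #     agent = ActorCriticAgent.init(env, actor, critic)
    #     agent.set_actor_optim(Adam, lr=1e-4)
    #     agent.set_critic_optim(Adam, lr=1e-3)
    #     action = agent.policy(state)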
class Env(object):
    ''' Env \n
    Wrapper class around gym.Env for reinforcement learning.
    Subclasses are expected to define state_transformer and reward_transformer,
    which are applied to every state and reward returned by the environment.

    Args:
        env (str): task name
    '''
def __init__(self, env):
self.env = gym.make(env)
self.steps = 0
def __str__(self):
return str(self.__class__.__name__)
@property
def max_steps(self):
return self.env._max_episode_steps
@property
def action_space(self):
return self.env.action_space
@property
def observation_space(self):
return self.env.observation_space
    def reset(self):
self.steps = 0
return self.state_transformer(self.env.reset())
    def render(self, **kwargs):
return self.env.render(**kwargs)
    def close(self):
self.env.close()
    def step(self, action):
self.steps += 1
next_state, reward, done, info = self.env.step(action)
return self.state_transformer(next_state), self.reward_transformer(reward), done, info
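    # A hypothetical subclass sketch showing the two transformer hooks this wrapper
    # relies on; returning Tensors is an assumption based on how play() reads
    # reward.data:
    #
    #     class CartPole(Env):
    #         def __init__(self):
    #             super().__init__('CartPole-v0')
    #
    #         def state_transformer(self, state):
    #             return Tensor(state)
    #
    #         def reward_transformer(self, reward):
    #             return Tensor([reward])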
    def show(self, filename=None):
frames = []
try:
self.env.reset()
for _ in range(self.max_steps):
self.env.render()
_, _, done, _ = self.env.step(self.env.action_space.sample())
frames.append(self.env.render(mode='rgb_array'))
if done:
break
self.env.close()
if filename is not None:
self.animate(frames, filename)
        except Exception:
            self.env.close()
            raise Exception('[*] Exception occurred during the Env.show() process.')
    def animate(self, frames, filename):
plt.clf()
plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
result = plt.imshow(frames[0])
plt.axis('off')
video = animation.FuncAnimation(plt.gcf(), lambda t: result.set_data(frames[t]), frames=len(frames), interval=50)
video.save(filename+'.mp4')
plt.close()
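# Quick sanity check of the wrapper (a sketch; 'CartPole-v0' is only an example task id,
# and saving the mp4 requires an ffmpeg writer to be available to matplotlib):
#
#     env = Env('CartPole-v0')
#     env.show(filename='random_play')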