# -*- coding: utf-8 -*-
from ..core import Env, Tensor, np
from PIL import Image
class AtariBase(Env):
    ''' Base class for the Atari environments.

    Adds an ``actions`` property plus a handful of static frame
    pre-processing helpers (image conversion, grayscale, resize,
    normalize) on top of ``Env``.
    '''

    @property
    def actions(self):
        ''' Action meanings reported by the unwrapped environment.

        Reads ``self.env``, which is presumably set up by ``Env.__init__``
        (not visible here -- confirm against ``core``).
        '''
        return self.env.unwrapped.get_action_meanings()

    @staticmethod
    def state_to_image(state):
        ''' Convert a raw state array into a PIL image via ``Image.fromarray``. '''
        return Image.fromarray(state)

    @staticmethod
    def to_gray(image):
        ''' Convert ``image`` to mode ``'L'`` (single-channel grayscale in PIL). '''
        return image.convert('L')

    @staticmethod
    def resize(image, width, height):
        ''' Resize ``image`` and return it as a channel-first float64 array.

        Args:
            image: object exposing ``resize((width, height))``; assumed to
                be a PIL image, whose array form is (height, width) or
                (height, width, channels).
            width: target width passed to ``image.resize``.
            height: target height passed to ``image.resize``.

        Returns:
            A float64 array. A 2-D result gets a leading axis of size 1;
            a 3-D result has its last axis moved to the front while the
            first two axes keep their relative order.

        Raises:
            ValueError: if the resized array is neither 2-D nor 3-D.
        '''
        resized = np.array(image.resize((width, height))).astype('float64')
        if resized.ndim == 2:
            return np.expand_dims(resized, axis=0)
        if resized.ndim == 3:
            # Move only the channel axis to the front. Swapping axes 0 and 2
            # would also transpose the two spatial axes, which disagrees
            # with the layout produced by the 2-D branch above.
            return np.transpose(resized, (2, 0, 1))
        raise ValueError('[*] Unexpected dimension was given.')

    @staticmethod
    def normalize(image):
        ''' Divide ``image`` by 255.0 (intended for 0-255 pixel values). '''
        return image / 255.0
class BreakOut(AtariBase):
    ''' BreakOut

    Maximize your score in the Atari 2600 game Breakout.

    Observation:
        Gym default:
            Type: Box(210, 160, 3)
            RGB image
        Transformed:
            (1, 84, 84)
            BW image

    Actions:
        Discrete(4)

        Num  Action
        0    no operation
        1    fire
        2    move right
        3    move left
    '''

    def __init__(self, width=84, height=84):
        ''' Create the ``Breakout-v0`` environment and record the frame size.

        Args:
            width: stored as ``self.width``.
            height: stored as ``self.height``.
        '''
        env_id = 'Breakout-v0'
        super().__init__(env_id)
        # Assigned after the base initializer, as in the original ordering.
        self.width, self.height = width, height
class BreakOutRAM(AtariBase):
    ''' BreakOutRAM

    Maximize your score in the Atari 2600 game Breakout.

    Observation:
        Box(128,)
        the RAM of the Atari machine

    Actions:
        Discrete(4)

        Num  Action
        0    no operation
        1    fire
        2    move right
        3    move left
    '''

    def __init__(self):
        ''' Create the ``Breakout-ram-v0`` environment. '''
        env_id = 'Breakout-ram-v0'
        super().__init__(env_id)
class Pong(AtariBase):
    ''' Pong

    Maximize your score in the Atari 2600 game Pong.

    Observation:
        Gym default:
            Type: Box(210, 160, 3)
            RGB image
        Transformed:
            (1, 84, 84)
            BW image

    Actions:
        Discrete(6)

        Num  Action
        0    no operation
        1    fire
        2    move right
        3    move left
        4    RIGHTFIRE
        5    LEFTFIRE
    '''

    def __init__(self, width=84, height=84):
        ''' Create the ``Pong-v0`` environment and record the frame size.

        Args:
            width: stored as ``self.width``.
            height: stored as ``self.height``.
        '''
        super().__init__('Pong-v0')
        self.width = width
        self.height = height
class PongRAM(AtariBase):
    ''' PongRAM

    Maximize your score in the Atari 2600 game Pong.

    Observation:
        Box(128,)
        the RAM of the Atari machine

    Actions:
        Discrete(6)

        Num  Action
        0    no operation
        1    fire
        2    move right
        3    move left
        4    RIGHTFIRE
        5    LEFTFIRE
    '''

    def __init__(self):
        ''' Create the ``Pong-ram-v0`` environment. '''
        super().__init__('Pong-ram-v0')