diff --git a/.github/workflows/action.yaml b/.github/workflows/action.yaml index 4b511d92..232c0c6a 100644 --- a/.github/workflows/action.yaml +++ b/.github/workflows/action.yaml @@ -13,6 +13,7 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: [3.7, 3.8, 3.9] + poetry-version: [1.1.5] steps: - name: Checkout uses: actions/checkout@v2 @@ -20,11 +21,53 @@ jobs: uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -r handyrl/envs/kaggle/requirements.txt + - uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + - run: | + poetry install - name: pytest run: | - python -m pytest tests + poetry run python -m pytest -v + lint: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + poetry-version: [1.1.5] + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + - run: | + poetry install + - run: | + poetry run pysen run lint + validate-requirements: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + poetry-version: [1.1.5] + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - uses: abatilo/actions-poetry@v2.0.0 + with: + poetry-version: ${{ matrix.poetry-version }} + - run: | + # If failed, do + # $ poetry export --without-hashes -f requirements.txt --output requirements.txt + # $ poetry export --without-hashes -f requirements.txt --output requirements-dev.txt --dev + diff <(poetry export --without-hashes -f requirements.txt | grep -v 
'^Warning:') requirements.txt + diff <(poetry export --without-hashes -f requirements.txt --dev | grep -v '^Warning:') requirements-dev.txt diff --git a/README.md b/README.md index 05c7eb64..8950fe96 100755 --- a/README.md +++ b/README.md @@ -53,7 +53,13 @@ pip3 install -r requirements.txt To use games of kaggle environments (e.g. Hungry Geese) you can install also additional dependencies. ``` -pip3 install -r handyrl/envs/kaggle/requirements.txt +pip3 install -r requirements-dev.txt +``` + +Or equivalently, you can use [poetry](https://python-poetry.org/). + +``` +poetry install ``` @@ -115,3 +121,35 @@ NOTE: Default opponent AI is random agent implemented in `evaluation.py`. You ca * [Month 1 Winner in Hungry Geese (Kaggle)](https://www.kaggle.com/c/hungry-geese/discussion/222941) * [The 5th solution in Google Research Football with Manchester City F.C. (Kaggle)](https://www.kaggle.com/c/google-football/discussion/203412) + + +## How to develop HandyRL + +### Lint + +```console +poetry run pysen run lint +``` + +You can fix some errors automatically: + +```console +poetry run pysen run foramt +``` + +### Test + +```console +poetry run pysen run pytest -v +``` + +### Tips + +If you do not use IDE, the following command helps you: + +```console +# At the first time, install `inotifywait`. +sudo apt-get install -y inotify-tools + +function COMMAND { clear; poetry run pysen run lint && poetry run pytest -v }; COMMAND; while true; do if inotifywait -r -e modify . 
2>/dev/null | git check-ignore --stdin >/dev/null; then :; else COMMAND; fi; done +``` diff --git a/handyrl/agent.py b/handyrl/agent.py index a43a41af..984d4c96 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -24,32 +24,32 @@ def observe(self, env, player, show=False): class RuleBasedAgent(RandomAgent): def action(self, env, player, show=False): - if hasattr(env, 'rule_based_action'): + if hasattr(env, "rule_based_action"): return env.rule_based_action(player) else: return random.choice(env.legal_actions(player)) def view(env, player=None): - if hasattr(env, 'view'): + if hasattr(env, "view"): env.view(player=player) else: print(env) def view_transition(env): - if hasattr(env, 'view_transition'): + if hasattr(env, "view_transition"): env.view_transition() else: pass def print_outputs(env, prob, v): - if hasattr(env, 'print_outputs'): + if hasattr(env, "print_outputs"): env.print_outputs(prob, v) else: - print('v = %f' % v) - print('p = %s' % (prob * 1000).astype(int)) + print("v = %f" % v) + print("p = %s" % (prob * 1000).astype(int)) class Agent: @@ -65,14 +65,14 @@ def reset(self, env, show=False): def plan(self, obs): outputs = self.model.inference(obs, self.hidden) - self.hidden = outputs.pop('hidden', None) + self.hidden = outputs.pop("hidden", None) return outputs def action(self, env, player, show=False): outputs = self.plan(env.observation(player)) actions = env.legal_actions(player) - p = outputs['policy'] - v = outputs.get('value', None) + p = outputs["policy"] + v = outputs.get("value", None) mask = np.ones_like(p) mask[actions] = 0 p -= mask * 1e32 @@ -90,7 +90,7 @@ def action(self, env, player, show=False): def observe(self, env, player, show=False): if self.observation: outputs = self.plan(env.observation(player)) - v = outputs.get('value', None) + v = outputs.get("value", None) if show: view(env, player=player) if self.observation: @@ -106,7 +106,7 @@ def plan(self, obs): for i, model in enumerate(self.model): o = model.inference(obs, 
self.hidden[i]) for k, v in o: - if k == 'hidden': + if k == "hidden": self.hidden[i] = v else: outputs[k] = outputs.get(k, []) + [o] diff --git a/handyrl/connection.py b/handyrl/connection.py index 28288b5f..e84a2f8a 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -2,14 +2,14 @@ # Licensed under The MIT License [see LICENSE for details] import io -import time -import struct -import socket -import pickle -import threading -import queue import multiprocessing as mp import multiprocessing.connection as connection +import pickle +import queue +import socket +import struct +import threading +from typing import Any, List, Optional def send_recv(conn, sdata): @@ -45,7 +45,7 @@ def _recv(self, size): def recv(self): buf = self._recv(4) - size, = struct.unpack("!i", buf.getvalue()) + (size,) = struct.unpack("!i", buf.getvalue()) buf = self._recv(size) return pickle.loads(buf.getvalue()) @@ -72,11 +72,8 @@ def send(self, msg): def open_socket_connection(port, reuse=False): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEADDR, - sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1 - ) - sock.bind(('', int(port))) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1) + sock.bind(("", int(port))) return sock @@ -99,7 +96,7 @@ def connect_socket_connection(host, port): try: sock.connect((host, int(port))) except ConnectionRefusedError: - print('failed to connect %s %d' % (host, port)) + print("failed to connect %s %d" % (host, port)) return PickledConnection(sock) @@ -165,7 +162,7 @@ def start(self): thread.start() def _sender(self): - print('start sender') + print("start sender") while not self.shutdown_flag: data = next(self.send_generator) while not self.shutdown_flag: @@ -175,10 +172,10 @@ def _sender(self): break except queue.Empty: pass - print('finished sender') + print("finished sender") def _receiver(self, index): - 
print('start receiver %d' % index) + print("start receiver %d" % index) conns = [conn for i, conn in enumerate(self.conns) if i % self.num_receivers == index] while not self.shutdown_flag: tmp_conns = connection.wait(conns) @@ -193,11 +190,13 @@ def _receiver(self, index): break except queue.Full: pass - print('finished receiver %d' % index) + print("finished receiver %d" % index) class QueueCommunicator: - def __init__(self, conns=[]): + def __init__(self, conns: Optional[List[Any]] = None): + conns = [] if conns is None else conns + self.input_queue = queue.Queue(maxsize=256) self.output_queue = queue.Queue(maxsize=256) self.conns = {} @@ -228,7 +227,7 @@ def add_connection(self, conn): self.conn_index += 1 def disconnect(self, conn): - print('disconnected') + print("disconnected") self.conns.pop(conn, None) def _send_thread(self): diff --git a/handyrl/environment.py b/handyrl/environment.py index f470e816..8d09b442 100755 --- a/handyrl/environment.py +++ b/handyrl/environment.py @@ -4,29 +4,29 @@ # game environment import importlib - +from typing import Any, Dict, Optional ENVS = { - 'TicTacToe': 'handyrl.envs.tictactoe', - 'Geister': 'handyrl.envs.geister', - 'ParallelTicTacToe': 'handyrl.envs.parallel_tictactoe', - 'HungryGeese': 'handyrl.envs.kaggle.hungry_geese', + "TicTacToe": "handyrl.envs.tictactoe", + "Geister": "handyrl.envs.geister", + "ParallelTicTacToe": "handyrl.envs.parallel_tictactoe", + "HungryGeese": "handyrl.envs.kaggle.hungry_geese", } def prepare_env(env_args): - env_name = env_args['env'] + env_name = env_args["env"] env_source = ENVS.get(env_name, env_name) env_module = importlib.import_module(env_source) if env_module is None: print("No environment %s" % env_name) - elif hasattr(env_module, 'prepare'): + elif hasattr(env_module, "prepare"): env_module.prepare() def make_env(env_args): - env_name = env_args['env'] + env_name = env_args["env"] env_source = ENVS.get(env_name, env_name) env_module = importlib.import_module(env_source) @@ 
-38,17 +38,18 @@ def make_env(env_args): # base class of Environment + class BaseEnvironment: - def __init__(self, args={}): + def __init__(self, args: Optional[Dict[Any, Any]] = None): pass def __str__(self): - return '' + return "" # # Should be defined in all games # - def reset(self, args={}): + def reset(self, args: Optional[Dict[Any, Any]]): raise NotImplementedError() # @@ -135,7 +136,7 @@ def str2action(self, s, player=None): # Should be defined if you use network battle mode # def diff_info(self, player=None): - return '' + return "" # # Should be defined if you use network battle mode diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py index a0306d7d..156ccaea 100755 --- a/handyrl/envs/geister.py +++ b/handyrl/envs/geister.py @@ -3,8 +3,9 @@ # implementation of Geister -import random import itertools +import random +from typing import Any, Dict, Optional import numpy as np import torch @@ -30,20 +31,21 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias): out_channels=4 * self.hidden_dim, kernel_size=self.kernel_size, padding=self.padding, - bias=self.bias + bias=self.bias, ) def init_hidden(self, input_size, batch_size): if batch_size is None: # for inference - return tuple([ - np.zeros((self.hidden_dim, *input_size), dtype=np.float32), - np.zeros((self.hidden_dim, *input_size), dtype=np.float32) - ]) + return tuple( + [ + np.zeros((self.hidden_dim, *input_size), dtype=np.float32), + np.zeros((self.hidden_dim, *input_size), dtype=np.float32), + ] + ) else: # for training - return tuple([ - torch.zeros(*batch_size, self.hidden_dim, *input_size), - torch.zeros(*batch_size, self.hidden_dim, *input_size) - ]) + return tuple( + [torch.zeros(*batch_size, self.hidden_dim, *input_size), torch.zeros(*batch_size, self.hidden_dim, *input_size)] + ) def forward(self, input_tensor, cur_state): h_cur, c_cur = cur_state @@ -70,12 +72,9 @@ def __init__(self, num_layers, input_dim, hidden_dim, kernel_size=3, bias=True): blocks = [] for _ in 
range(self.num_layers): - blocks.append(ConvLSTMCell( - input_dim=input_dim, - hidden_dim=hidden_dim, - kernel_size=(kernel_size, kernel_size), - bias=bias - )) + blocks.append( + ConvLSTMCell(input_dim=input_dim, hidden_dim=hidden_dim, kernel_size=(kernel_size, kernel_size), bias=bias) + ) self.blocks = nn.ModuleList(blocks) def init_hidden(self, input_size, batch_size): @@ -149,7 +148,7 @@ def init_hidden(self, batch_size=None): return self.body.init_hidden(self.input_size[1:], batch_size) def forward(self, x, hidden): - b, s = x['board'], x['scalar'] + b, s = x["board"], x["scalar"] h_s = s.view(*s.size(), 1, 1).repeat(1, 1, 6, 6) h = torch.cat([h_s, b], -3) @@ -164,26 +163,23 @@ def forward(self, x, hidden): h_v = self.head_v(h) h_r = self.head_r(h) - return {'policy': h_p, 'value': torch.tanh(h_v), 'return': h_r, 'hidden': hidden} + return {"policy": h_p, "value": torch.tanh(h_v), "return": h_r, "hidden": hidden} class Environment(BaseEnvironment): - X, Y = 'ABCDEF', '123456' + X, Y = "ABCDEF", "123456" BLACK, WHITE = 0, 1 BLUE, RED = 0, 1 - C = 'BW' - T = 'BR' - P = {-1: '_', 0: 'B', 1: 'R', 2: 'b', 3: 'r', 4: '*'} + C = "BW" + T = "BR" + P = {-1: "_", 0: "B", 1: "R", 2: "b", 3: "r", 4: "*"} # original positions to set pieces OPOS = [ - ['B2', 'C2', 'D2', 'E2', 'B1', 'C1', 'D1', 'E1'], - ['E5', 'D5', 'C5', 'B5', 'E6', 'D6', 'C6', 'B6'], + ["B2", "C2", "D2", "E2", "B1", "C1", "D1", "E1"], + ["E5", "D5", "C5", "B5", "E6", "D6", "C6", "B6"], ] # goal positions - GPOS = np.array([ - [(-1, 5), (6, 5)], - [(-1, 0), (6, 0)] - ], dtype=np.int32) + GPOS = np.array([[(-1, 5), (6, 5)], [(-1, 0), (6, 0)]], dtype=np.int32) D = np.array([(-1, 0), (0, -1), (0, 1), (1, 0)], dtype=np.int32) OSEQ = list(itertools.combinations([i for i in range(8)], 4)) @@ -192,7 +188,9 @@ def __init__(self, args=None): super().__init__() self.reset() - def reset(self, args={}): + def reset(self, args: Optional[Dict[Any, Any]] = None): + args = {} if args is None else args + self.args = args 
self.board = -np.ones((6, 6), dtype=np.int32) # (x, y) -1 is empty self.color = self.BLACK @@ -264,10 +262,10 @@ def position2str(self, pos): if self.onboard(pos): return self.X[pos[0]] + self.Y[pos[1]] else: - return '**' + return "**" def str2position(self, s): - if s != '**': + if s != "**": return np.array((self.X.find(s[0]), self.Y.find(s[1])), dtype=np.int32) else: return None @@ -296,7 +294,7 @@ def action2to(self, a, c): def action2str(self, a, player): if a >= 4 * 6 * 6: - return 's' + str(a - 4 * 6 * 6) + return "s" + str(a - 4 * 6 * 6) c = player pos_from = self.action2from(a, c) @@ -304,7 +302,7 @@ def action2str(self, a, player): return self.position2str(pos_from) + self.position2str(pos_to) def str2action(self, s, player): - if s[0] == 's': + if s[0] == "s": return 4 * 6 * 6 + int(s[1:]) c = player @@ -316,36 +314,38 @@ def str2action(self, s, player): for g in self.GPOS[c]: if ((pos_from - g) ** 2).sum() == 1: diff = g - pos_from - for d, dd in enumerate(self.D): + for d_, dd in enumerate(self.D): if np.array_equal(dd, diff): + d = d_ break break else: # check action direction diff = pos_to - pos_from - for d, dd in enumerate(self.D): + for d_, dd in enumerate(self.D): if np.array_equal(dd, diff): + d = d_ break return self.fromdirection2action(pos_from, d, c) def record_string(self): - return ' '.join([self.action2str(a, i % 2) for i, a in enumerate(self.record)]) + return " ".join([self.action2str(a, i % 2) for i, a in enumerate(self.record)]) def position_string(self): poss = [self.position2str(pos) for pos in self.piece_position] - return ','.join(poss) + return ",".join(poss) def __str__(self): # output state def _piece(p): return p if p == -1 or self.layouts[self.piece2color(p)] >= 0 else 4 - s = ' ' + ' '.join(self.Y) + '\n' + s = " " + " ".join(self.Y) + "\n" for i in range(6): - s += self.X[i] + ' ' + ' '.join([self.P[_piece(self.board[i, j])] for j in range(6)]) + '\n' - s += 'color = ' + self.C[self.color] + '\n' - s += 'record = ' + 
self.record_string() + s += self.X[i] + " " + " ".join([self.P[_piece(self.board[i, j])] for j in range(6)]) + "\n" + s += "color = " + self.C[self.color] + "\n" + s += "record = " + self.record_string() return s def _set(self, layout): @@ -401,25 +401,25 @@ def diff_info(self, player): info = {} if len(self.record) == 0: if self.turn_count > -2: - info['set'] = self.layouts[played_color] if color == played_color else -1 + info["set"] = self.layouts[played_color] if color == played_color else -1 else: - info['move'] = self.action2str(self.record[-1], played_color) + info["move"] = self.action2str(self.record[-1], played_color) if color == played_color and self.captured_type is not None: - info['captured'] = self.T[self.captured_type] + info["captured"] = self.T[self.captured_type] return info def update(self, info, reset): if reset: self.args = {**self.args, **info} self.reset(info) - elif 'set' in info: - self._set(info['set']) - elif 'move' in info: - action = self.str2action(info['move'], self.color) - if 'captured' in info: + elif "set" in info: + self._set(info["set"]) + elif "move" in info: + action = self.str2action(info["move"], self.color) + if "captured" in info: # set color to captured piece pos_to = self.action2to(action, self.color) - t = self.T.index(info['captured']) + t = self.T.index(info["captured"]) piece = self.colortype2piece(self.opponent(self.color), t) self.board[pos_to[0], pos_to[1]] = piece self.play(action) @@ -474,7 +474,7 @@ def legal_actions(self, _=None): if self.turn_count < 0: return [4 * 6 * 6 + i for i in range(70)] actions = [] - for pos in self.piece_position[self.color*8:(self.color+1)*8]: + for pos in self.piece_position[self.color * 8 : (self.color + 1) * 8]: if pos[0] == -1: continue t = self.piece2type(self.board[pos[0], pos[1]]) @@ -498,50 +498,54 @@ def observation(self, player=None): color = self.color if turn_view else self.opponent(self.color) opponent = self.opponent(color) - nbcolor = 
self.piece_cnt[self.colortype2piece(color, self.BLUE)] - nrcolor = self.piece_cnt[self.colortype2piece(color, self.RED )] - nbopp = self.piece_cnt[self.colortype2piece(opponent, self.BLUE)] - nropp = self.piece_cnt[self.colortype2piece(opponent, self.RED )] - - s = np.array([ - 1 if color == self.BLACK else 0, # my color is black - 1 if turn_view else 0, # view point is turn player - # the number of remained pieces - *[(1 if nbcolor == i else 0) for i in range(1, 5)], - *[(1 if nrcolor == i else 0) for i in range(1, 5)], - *[(1 if nbopp == i else 0) for i in range(1, 5)], - *[(1 if nropp == i else 0) for i in range(1, 5)] - ]).astype(np.float32) - - blue_c = self.board == self.colortype2piece(color, self.BLUE) - red_c = self.board == self.colortype2piece(color, self.RED) + nbcolor = self.piece_cnt[self.colortype2piece(color, self.BLUE)] + nrcolor = self.piece_cnt[self.colortype2piece(color, self.RED)] + nbopp = self.piece_cnt[self.colortype2piece(opponent, self.BLUE)] + nropp = self.piece_cnt[self.colortype2piece(opponent, self.RED)] + + s = np.array( + [ + 1 if color == self.BLACK else 0, # my color is black + 1 if turn_view else 0, # view point is turn player + # the number of remained pieces + *[(1 if nbcolor == i else 0) for i in range(1, 5)], + *[(1 if nrcolor == i else 0) for i in range(1, 5)], + *[(1 if nbopp == i else 0) for i in range(1, 5)], + *[(1 if nropp == i else 0) for i in range(1, 5)], + ] + ).astype(np.float32) + + blue_c = self.board == self.colortype2piece(color, self.BLUE) + red_c = self.board == self.colortype2piece(color, self.RED) blue_o = self.board == self.colortype2piece(opponent, self.BLUE) - red_o = self.board == self.colortype2piece(opponent, self.RED) - - b = np.stack([ - # board zone - np.ones_like(self.board), - # my/opponent's all pieces - blue_c + red_c, - blue_o + red_o, - # my blue/red pieces - blue_c, - red_c, - # opponent's blue/red pieces - blue_o if player is None else np.zeros_like(self.board), - red_o if player is None 
else np.zeros_like(self.board) - ]).astype(np.float32) + red_o = self.board == self.colortype2piece(opponent, self.RED) + + b = np.stack( + [ + # board zone + np.ones_like(self.board), + # my/opponent's all pieces + blue_c + red_c, + blue_o + red_o, + # my blue/red pieces + blue_c, + red_c, + # opponent's blue/red pieces + blue_o if player is None else np.zeros_like(self.board), + red_o if player is None else np.zeros_like(self.board), + ] + ).astype(np.float32) if color == self.WHITE: b = np.rot90(b, k=2, axes=(1, 2)) - return {'scalar': s, 'board': b} + return {"scalar": s, "board": b} def net(self): return GeisterNet -if __name__ == '__main__': +if __name__ == "__main__": e = Environment() for _ in range(100): e.reset() diff --git a/handyrl/envs/kaggle/hungry_geese.py b/handyrl/envs/kaggle/hungry_geese.py index 76d5f090..a0339107 100644 --- a/handyrl/envs/kaggle/hungry_geese.py +++ b/handyrl/envs/kaggle/hungry_geese.py @@ -7,7 +7,7 @@ # wrapper of Hungry Geese environment from kaggle import random -import itertools +from typing import Any, Dict, Optional import numpy as np import torch @@ -28,8 +28,8 @@ def __init__(self, input_dim, output_dim, kernel_size, bn): self.bn = nn.BatchNorm2d(output_dim) if bn else None def forward(self, x): - h = torch.cat([x[:,:,:,-self.edge_size[1]:], x, x[:,:,:,:self.edge_size[1]]], dim=3) - h = torch.cat([h[:,:,-self.edge_size[0]:], h, h[:,:,:self.edge_size[0]]], dim=2) + h = torch.cat([x[:, :, :, -self.edge_size[1] :], x, x[:, :, :, : self.edge_size[1]]], dim=3) + h = torch.cat([h[:, :, -self.edge_size[0] :], h, h[:, :, : self.edge_size[0]]], dim=2) h = self.conv(h) h = self.bn(h) if self.bn is not None else h return h @@ -49,24 +49,28 @@ def forward(self, x, _=None): h = F.relu_(self.conv0(x)) for block in self.blocks: h = F.relu_(h + block(h)) - h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1) + h_head = (h * x[:, :1]).view(h.size(0), h.size(1), -1).sum(-1) h_avg = h.view(h.size(0), h.size(1), -1).mean(-1) p = 
self.head_p(h_head) v = torch.tanh(self.head_v(torch.cat([h_head, h_avg], 1))) - return {'policy': p, 'value': v} + return {"policy": p, "value": v} class Environment(BaseEnvironment): - ACTION = ['NORTH', 'SOUTH', 'WEST', 'EAST'] + ACTION = ["NORTH", "SOUTH", "WEST", "EAST"] NUM_AGENTS = 4 - def __init__(self, args={}): + def __init__(self, args: Optional[Dict[Any, Any]] = None): + args = {} if args is None else args + super().__init__() self.env = make("hungry_geese") self.reset() - def reset(self, args={}): + def reset(self, args: Optional[Dict[Any, Any]] = None): + args = {} if args is None else args + obs = self.env.reset(num_agents=self.NUM_AGENTS) self.update((obs, {}), True) @@ -101,54 +105,54 @@ def direction(self, pos_from, pos_to): def __str__(self): # output state - obs = self.obs_list[-1][0]['observation'] - colors = ['\033[33m', '\033[34m', '\033[32m', '\033[31m'] - color_end = '\033[0m' + obs = self.obs_list[-1][0]["observation"] + colors = ["\033[33m", "\033[34m", "\033[32m", "\033[31m"] + color_end = "\033[0m" def check_cell(pos): - for i, geese in enumerate(obs['geese']): + for i, geese in enumerate(obs["geese"]): if pos in geese: if pos == geese[0]: - return i, 'h' + return i, "h" if pos == geese[-1]: - return i, 't' + return i, "t" index = geese.index(pos) pos_prev = geese[index - 1] if index > 0 else None pos_next = geese[index + 1] if index < len(geese) - 1 else None directions = [self.direction(pos, pos_prev), self.direction(pos, pos_next)] return i, directions - if pos in obs['food']: - return 'f' + if pos in obs["food"]: + return "f" return None def cell_string(cell): if cell is None: - return '.' - elif cell == 'f': - return 'f' + return "." 
+ elif cell == "f": + return "f" else: index, directions = cell - if directions == 'h': - return colors[index] + '@' + color_end - elif directions == 't': - return colors[index] + '*' + color_end + if directions == "h": + return colors[index] + "@" + color_end + elif directions == "t": + return colors[index] + "*" + color_end elif max(directions) < 2: - return colors[index] + '|' + color_end + return colors[index] + "|" + color_end elif min(directions) >= 2: - return colors[index] + '-' + color_end + return colors[index] + "-" + color_end else: - return colors[index] + '+' + color_end + return colors[index] + "+" + color_end cell_status = [check_cell(pos) for pos in range(7 * 11)] - s = 'turn %d\n' % len(self.obs_list) + s = "turn %d\n" % len(self.obs_list) for x in range(7): for y in range(11): pos = x * 11 + y s += cell_string(cell_status[pos]) - s += '\n' - for i, geese in enumerate(obs['geese']): - s += colors[i] + str(len(geese) or '-') + color_end + ' ' + s += "\n" + for i, geese in enumerate(obs["geese"]): + s += colors[i] + str(len(geese) or "-") + color_end + " " return s def step(self, actions): @@ -161,19 +165,19 @@ def diff_info(self, _): def turns(self): # players to move - return [p for p in self.players() if self.obs_list[-1][p]['status'] == 'ACTIVE'] + return [p for p in self.players() if self.obs_list[-1][p]["status"] == "ACTIVE"] def terminal(self): # check whether terminal state or not for obs in self.obs_list[-1]: - if obs['status'] == 'ACTIVE': + if obs["status"] == "ACTIVE": return False return True def outcome(self): # return terminal outcomes # 1st: 1.0 2nd: 0.33 3rd: -0.33 4th: -1.00 - rewards = {o['observation']['index']: o['reward'] for o in self.obs_list[-1]} + rewards = {o["observation"]["index"]: o["reward"] for o in self.obs_list[-1]} outcomes = {p: 0 for p in self.players()} for p, r in rewards.items(): for pp, rr in rewards.items(): @@ -196,12 +200,13 @@ def players(self): return list(range(self.NUM_AGENTS)) def 
rule_based_action(self, player): - from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, GreedyAgent - action_map = {'N': Action.NORTH, 'S': Action.SOUTH, 'W': Action.WEST, 'E': Action.EAST} + from kaggle_environments.envs.hungry_geese.hungry_geese import Action, Configuration, GreedyAgent, Observation + + action_map = {"N": Action.NORTH, "S": Action.SOUTH, "W": Action.WEST, "E": Action.EAST} - agent = GreedyAgent(Configuration({'rows': 7, 'columns': 11})) + agent = GreedyAgent(Configuration({"rows": 7, "columns": 11})) agent.last_action = action_map[self.ACTION[self.last_actions[player]][0]] if player in self.last_actions else None - obs = {**self.obs_list[-1][0]['observation'], **self.obs_list[-1][player]['observation']} + obs = {**self.obs_list[-1][0]["observation"], **self.obs_list[-1][player]["observation"]} action = agent(Observation(obs)) return self.ACTION.index(action) @@ -213,9 +218,9 @@ def observation(self, player=None): player = 0 b = np.zeros((self.NUM_AGENTS * 4 + 1, 7 * 11), dtype=np.float32) - obs = self.obs_list[-1][0]['observation'] + obs = self.obs_list[-1][0]["observation"] - for p, geese in enumerate(obs['geese']): + for p, geese in enumerate(obs["geese"]): # head position for pos in geese[:1]: b[0 + (p - player) % self.NUM_AGENTS, pos] = 1 @@ -228,19 +233,19 @@ def observation(self, player=None): # previous head position if len(self.obs_list) > 1: - obs_prev = self.obs_list[-2][0]['observation'] - for p, geese in enumerate(obs_prev['geese']): + obs_prev = self.obs_list[-2][0]["observation"] + for p, geese in enumerate(obs_prev["geese"]): for pos in geese[:1]: b[12 + (p - player) % self.NUM_AGENTS, pos] = 1 # food - for pos in obs['food']: + for pos in obs["food"]: b[16, pos] = 1 return b.reshape(-1, 7, 11) -if __name__ == '__main__': +if __name__ == "__main__": e = Environment() for _ in range(100): e.reset() diff --git a/handyrl/envs/kaggle/requirements.txt b/handyrl/envs/kaggle/requirements.txt 
deleted file mode 100644 index 1d6c875a..00000000 --- a/handyrl/envs/kaggle/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -kaggle_environments -requests diff --git a/handyrl/envs/parallel_tictactoe.py b/handyrl/envs/parallel_tictactoe.py index b70e3aea..d9fb5003 100755 --- a/handyrl/envs/parallel_tictactoe.py +++ b/handyrl/envs/parallel_tictactoe.py @@ -12,9 +12,9 @@ class Environment(TicTacToe): def __str__(self): - s = ' ' + ' '.join(self.Y) + '\n' + s = " " + " ".join(self.Y) + "\n" for i in range(3): - s += self.X[i] + ' ' + ' '.join([self.C[self.board[i, j]] for j in range(3)]) + '\n' + s += self.X[i] + " " + " ".join([self.C[self.board[i, j]] for j in range(3)]) + "\n" return s def step(self, actions): @@ -29,10 +29,12 @@ def _step(self, action, selected_player): self.board[x, y] = selected_color # check winning condition - if self.board[x, :].sum() == 3 * selected_color \ - or self.board[:, y].sum() == 3 * selected_color \ - or (x == y and np.diag(self.board, k=0).sum() == 3 * selected_color) \ - or (x == 2 - y and np.diag(self.board[::-1, :], k=0).sum() == 3 * selected_color): + if ( + self.board[x, :].sum() == 3 * selected_color + or self.board[:, y].sum() == 3 * selected_color + or (x == y and np.diag(self.board, k=0).sum() == 3 * selected_color) + or (x == 2 - y and np.diag(self.board[::-1, :], k=0).sum() == 3 * selected_color) + ): self.win_color = selected_color self.record.append((selected_color, action)) @@ -48,7 +50,7 @@ def update(self, info, reset): self.reset() else: saction, scolor = info.split(":") - action, player = self.str2action(saction), 'OX'.index(scolor) + action, player = self.str2action(saction), "OX".index(scolor) self._step(action, player) def turn(self): @@ -58,7 +60,7 @@ def turns(self): return self.players() -if __name__ == '__main__': +if __name__ == "__main__": e = Environment() for _ in range(100): e.reset() diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py index c6403b7f..4d05d421 100755 --- 
a/handyrl/envs/tictactoe.py +++ b/handyrl/envs/tictactoe.py @@ -3,7 +3,6 @@ # implementation of Tic-Tac-Toe -import copy import random import numpy as np @@ -19,10 +18,7 @@ def __init__(self, filters0, filters1, kernel_size, bn, bias=True): super().__init__() if bn: bias = False - self.conv = nn.Conv2d( - filters0, filters1, kernel_size, - stride=1, padding=kernel_size//2, bias=bias - ) + self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size // 2, bias=bias) self.bn = nn.BatchNorm2d(filters1) if bn else None def forward(self, x): @@ -66,13 +62,13 @@ def forward(self, x, hidden=None): h_p = self.head_p(h) h_v = self.head_v(h) - return {'policy': h_p, 'value': torch.tanh(h_v)} + return {"policy": h_p, "value": torch.tanh(h_v)} class Environment(BaseEnvironment): - X, Y = 'ABC', '123' + X, Y = "ABC", "123" BLACK, WHITE = 1, -1 - C = {0: '_', BLACK: 'O', WHITE: 'X'} + C = {0: "_", BLACK: "O", WHITE: "X"} def __init__(self, args=None): super().__init__() @@ -91,13 +87,13 @@ def str2action(self, s, _=None): return self.X.find(s[0]) * 3 + self.Y.find(s[1]) def record_string(self): - return ' '.join([self.action2str(a) for a in self.record]) + return " ".join([self.action2str(a) for a in self.record]) def __str__(self): - s = ' ' + ' '.join(self.Y) + '\n' + s = " " + " ".join(self.Y) + "\n" for i in range(3): - s += self.X[i] + ' ' + ' '.join([self.C[self.board[i, j]] for j in range(3)]) + '\n' - s += 'record = ' + self.record_string() + s += self.X[i] + " " + " ".join([self.C[self.board[i, j]] for j in range(3)]) + "\n" + s += "record = " + self.record_string() return s def play(self, action, _=None): @@ -107,10 +103,12 @@ def play(self, action, _=None): self.board[x, y] = self.color # check winning condition - win = self.board[x, :].sum() == 3 * self.color \ - or self.board[:, y].sum() == 3 * self.color \ - or (x == y and np.diag(self.board, k=0).sum() == 3 * self.color) \ + win = ( + self.board[x, :].sum() == 3 * self.color + or 
self.board[:, y].sum() == 3 * self.color + or (x == y and np.diag(self.board, k=0).sum() == 3 * self.color) or (x == 2 - y and np.diag(self.board[::-1, :], k=0).sum() == 3 * self.color) + ) if win: self.win_color = self.color @@ -164,15 +162,13 @@ def observation(self, player=None): # input feature for neural nets turn_view = player is None or player == self.turn() color = self.color if turn_view else -self.color - a = np.stack([ - np.ones_like(self.board) if turn_view else np.zeros_like(self.board), - self.board == color, - self.board == -color - ]).astype(np.float32) + a = np.stack( + [np.ones_like(self.board) if turn_view else np.zeros_like(self.board), self.board == color, self.board == -color] + ).astype(np.float32) return a -if __name__ == '__main__': +if __name__ == "__main__": e = Environment() for _ in range(100): e.reset() diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 9cbb11e8..c0551c39 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -3,15 +3,14 @@ # evaluation of policies or planning algorithms +import multiprocessing as mp import random import time -import multiprocessing as mp - -from .environment import prepare_env, make_env -from .connection import send_recv, accept_socket_connections, connect_socket_connection -from .agent import RandomAgent, RuleBasedAgent, Agent, EnsembleAgent, SoftAgent -from .agent import view, view_transition +from typing import Any, Dict, Optional +from .agent import Agent, RandomAgent, view_transition +from .connection import accept_socket_connections, connect_socket_connection, send_recv +from .environment import make_env, prepare_env network_match_port = 9876 @@ -25,18 +24,18 @@ def __init__(self, agent, env, conn): def run(self): while True: command, args = self.conn.recv() - if command == 'quit': + if command == "quit": break - elif command == 'outcome': - print('outcome = %f' % args[0]) + elif command == "outcome": + print("outcome = %f" % args[0]) elif hasattr(self.agent, command): 
ret = getattr(self.agent, command)(self.env, *args, show=True) - if command == 'action': + if command == "action": player = args[0] ret = self.env.action2str(ret, player) else: ret = getattr(self.env, command)(*args) - if command == 'update': + if command == "update": reset = args[1] if reset: self.agent.reset(self.env, show=True) @@ -49,27 +48,30 @@ def __init__(self, conn): self.conn = conn def update(self, data, reset): - return send_recv(self.conn, ('update', [data, reset])) + return send_recv(self.conn, ("update", [data, reset])) def outcome(self, outcome): - return send_recv(self.conn, ('outcome', [outcome])) + return send_recv(self.conn, ("outcome", [outcome])) def action(self, player): - return send_recv(self.conn, ('action', [player])) + return send_recv(self.conn, ("action", [player])) def observe(self, player): - return send_recv(self.conn, ('observe', [player])) + return send_recv(self.conn, ("observe", [player])) + + +def exec_match(env, agents, critic, show=False, game_args: Optional[Dict[Any, Any]] = None): + """ match with shared game environment """ + game_args = {} if game_args is None else game_args -def exec_match(env, agents, critic, show=False, game_args={}): - ''' match with shared game environment ''' if env.reset(game_args): return None for agent in agents.values(): agent.reset(env, show=show) while not env.terminal(): if show and critic is not None: - print('cv = ', critic.observe(env, None, show=False)[0]) + print("cv = ", critic.observe(env, None, show=False)[0]) turn_players = env.turns() actions = {} for p, agent in agents.items(): @@ -83,12 +85,15 @@ def exec_match(env, agents, critic, show=False, game_args={}): view_transition(env) outcome = env.outcome() if show: - print('final outcome = %s' % outcome) + print("final outcome = %s" % outcome) return outcome -def exec_network_match(env, network_agents, critic, show=False, game_args={}): - ''' match with divided game environment ''' +def exec_network_match(env, network_agents, critic, 
show=False, game_args: Optional[Dict[Any, Any]] = None): + """ match with divided game environment """ + + game_args = {} if game_args is None else game_args + if env.reset(game_args): return None for p, agent in network_agents.items(): @@ -96,7 +101,7 @@ def exec_network_match(env, network_agents, critic, show=False, game_args={}): agent.update(info, True) while not env.terminal(): if show and critic is not None: - print('cv = ', critic.observe(env, None, show=False)[0]) + print("cv = ", critic.observe(env, None, show=False)[0]) turn_players = env.turns() actions = {} for p, agent in network_agents.items(): @@ -128,12 +133,12 @@ def execute(self, models, args): if model is None: agents[p] = self.default_agent else: - agents[p] = Agent(model, self.args['observation']) + agents[p] = Agent(model, self.args["observation"]) outcome = exec_match(self.env, agents, None) if outcome is None: - print('None episode in evaluation!') + print("None episode in evaluation!") return None - return {'args': args, 'result': outcome} + return {"args": args, "result": outcome} def wp_func(results): @@ -146,13 +151,13 @@ def wp_func(results): def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue, seed, show=False): random.seed(seed + index) - env = make_env({**env_args, 'id': index}) + env = make_env({**env_args, "id": index}) while True: args = in_queue.get() if args is None: break g, agent_ids, pat_idx, game_args = args - print('*** Game %d ***' % g) + print("*** Game %d ***" % g) agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)} if isinstance(list(agent_map.values())[0], NetworkAgent): outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args) @@ -166,7 +171,7 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g in_queue, out_queue = mp.Queue(), mp.Queue() args_cnt = 0 total_results, result_map = [{} for _ in agents], [{} for _ in agents] - print('total games = %d' % 
(len(args_patterns) * num_games)) + print("total games = %d" % (len(args_patterns) * num_games)) time.sleep(0.1) for pat_idx, args in args_patterns.items(): for i in range(num_games): @@ -174,7 +179,7 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g # When playing two player game, # the number of games with first or second player is equalized. first_agent = 0 if i < (num_games + 1) // 2 else 1 - tmp_pat_idx, agent_ids = (pat_idx + '-F', [0, 1]) if first_agent == 0 else (pat_idx + '-S', [1, 0]) + tmp_pat_idx, agent_ids = (pat_idx + "-F", [0, 1]) if first_agent == 0 else (pat_idx + "-S", [1, 0]) else: tmp_pat_idx, agent_ids = pat_idx, random.sample(list(range(len(agents))), len(agents)) in_queue.put((args_cnt, agent_ids, tmp_pat_idx, args)) @@ -214,10 +219,12 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g total_results[agent_id][oc] = total_results[agent_id].get(oc, 0) + 1 for p, r_map in enumerate(result_map): - print('---agent %d---' % p) + print("---agent %d---" % p) for pat_idx, results in r_map.items(): print(pat_idx, {k: results[k] for k in sorted(results.keys(), reverse=True)}, wp_func(results)) - print('total', {k: total_results[p][k] for k in sorted(total_results[p].keys(), reverse=True)}, wp_func(total_results[p])) + print( + "total", {k: total_results[p][k] for k in sorted(total_results[p].keys(), reverse=True)}, wp_func(total_results[p]) + ) def network_match_acception(n, env_args, num_agents, port): @@ -235,17 +242,16 @@ def network_match_acception(n, env_args, num_agents, port): waiting_conns = waiting_conns[1:] conn.send(env_args) # send accept with environment arguments - agents_list = [ - [NetworkAgent(accepted_conns[i * num_agents + j]) for j in range(num_agents)] - for i in range(n) - ] + agents_list = [[NetworkAgent(accepted_conns[i * num_agents + j]) for j in range(num_agents)] for i in range(n)] return agents_list def get_model(env, model_path): import torch + from .model 
import ModelWrapper + model = env.net()() model.load_state_dict(torch.load(model_path)) model.eval() @@ -259,54 +265,54 @@ def client_mp_child(env_args, model_path, conn): def eval_main(args, argv): - env_args = args['env_args'] + env_args = args["env_args"] prepare_env(env_args) env = make_env(env_args) - model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth' + model_path = argv[0] if len(argv) >= 1 else "models/latest.pth" num_games = int(argv[1]) if len(argv) >= 2 else 100 num_process = int(argv[2]) if len(argv) >= 3 else 1 agent1 = Agent(get_model(env, model_path)) critic = None - print('%d process, %d games' % (num_process, num_games)) + print("%d process, %d games" % (num_process, num_games)) seed = random.randrange(1e8) - print('seed = %d' % seed) + print("seed = %d" % seed) agents = [agent1] + [RandomAgent() for _ in range(len(env.players()) - 1)] - evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed) + evaluate_mp(env, agents, critic, env_args, {"default": {}}, num_process, num_games, seed) def eval_server_main(args, argv): - print('network match server mode') - env_args = args['env_args'] + print("network match server mode") + env_args = args["env_args"] prepare_env(env_args) env = make_env(env_args) num_games = int(argv[0]) if len(argv) >= 1 else 100 num_process = int(argv[1]) if len(argv) >= 2 else 1 - print('%d process, %d games' % (num_process, num_games)) + print("%d process, %d games" % (num_process, num_games)) seed = random.randrange(1e8) - print('seed = %d' % seed) + print("seed = %d" % seed) - evaluate_mp(env, [None] * len(env.players()), None, env_args, {'default': {}}, num_process, num_games, seed) + evaluate_mp(env, [None] * len(env.players()), None, env_args, {"default": {}}, num_process, num_games, seed) def eval_client_main(args, argv): - print('network match client mode') + print("network match client mode") while True: try: - host = argv[1] if len(argv) >= 2 else 'localhost' + host = argv[1] 
if len(argv) >= 2 else "localhost" conn = connect_socket_connection(host, network_match_port) env_args = conn.recv() except EOFError: break - model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth' + model_path = argv[0] if len(argv) >= 1 else "models/latest.pth" mp.Process(target=client_mp_child, args=(env_args, model_path, conn)).start() conn.close() diff --git a/handyrl/generation.py b/handyrl/generation.py index 63b7e553..6ddbd2d7 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -3,9 +3,9 @@ # episode generation -import random import bz2 import pickle +import random import numpy as np @@ -29,42 +29,42 @@ def generate(self, models, args): return None while not self.env.terminal(): - moment_keys = ['observation', 'policy', 'action_mask', 'action', 'value', 'reward', 'return'] + moment_keys = ["observation", "policy", "action_mask", "action", "value", "reward", "return"] moment = {key: {p: None for p in self.env.players()} for key in moment_keys} turn_players = self.env.turns() for player in self.env.players(): - if player in turn_players or self.args['observation']: + if player in turn_players or self.args["observation"]: obs = self.env.observation(player) model = models[player] outputs = model.inference(obs, hidden[player]) - hidden[player] = outputs.get('hidden', None) - v = outputs.get('value', None) + hidden[player] = outputs.get("hidden", None) + v = outputs.get("value", None) - moment['observation'][player] = obs - moment['value'][player] = v + moment["observation"][player] = obs + moment["value"][player] = v if player in turn_players: - p_ = outputs['policy'] + p_ = outputs["policy"] legal_actions = self.env.legal_actions(player) action_mask = np.ones_like(p_) * 1e32 action_mask[legal_actions] = 0 p = p_ - action_mask action = random.choices(legal_actions, weights=softmax(p[legal_actions]))[0] - moment['policy'][player] = p - moment['action_mask'][player] = action_mask - moment['action'][player] = action + moment["policy"][player] 
= p + moment["action_mask"][player] = action_mask + moment["action"][player] = action - err = self.env.step(moment['action']) + err = self.env.step(moment["action"]) if err: return None reward = self.env.reward() for player in self.env.players(): - moment['reward'][player] = reward.get(player, None) + moment["reward"][player] = reward.get(player, None) - moment['turn'] = turn_players + moment["turn"] = turn_players moments.append(moment) if len(moments) < 1: @@ -73,16 +73,17 @@ def generate(self, models, args): for player in self.env.players(): ret = 0 for i, m in reversed(list(enumerate(moments))): - ret = (m['reward'][player] or 0) + self.args['gamma'] * ret - moments[i]['return'][player] = ret + ret = (m["reward"][player] or 0) + self.args["gamma"] * ret + moments[i]["return"][player] = ret episode = { - 'args': args, 'steps': len(moments), - 'outcome': self.env.outcome(), - 'moment': [ - bz2.compress(pickle.dumps(moments[i:i+self.args['compress_steps']])) - for i in range(0, len(moments), self.args['compress_steps']) - ] + "args": args, + "steps": len(moments), + "outcome": self.env.outcome(), + "moment": [ + bz2.compress(pickle.dumps(moments[i : i + self.args["compress_steps"]])) + for i in range(0, len(moments), self.args["compress_steps"]) + ], } return episode @@ -90,5 +91,5 @@ def generate(self, models, args): def execute(self, models, args): episode = self.generate(models, args) if episode is None: - print('None episode in generation!') + print("None episode in generation!") return episode diff --git a/handyrl/losses.py b/handyrl/losses.py index af1b8bd6..c3128816 100755 --- a/handyrl/losses.py +++ b/handyrl/losses.py @@ -62,13 +62,13 @@ def compute_target(algorithm, values, returns, rewards, lmb, gamma, rhos, cs): if values is None: return None, 0 - if algorithm == 'MC': + if algorithm == "MC": return monte_carlo(values, returns) - elif algorithm == 'TD': + elif algorithm == "TD": return temporal_difference(values, returns, rewards, lmb, gamma) - elif 
algorithm == 'UPGO': + elif algorithm == "UPGO": return upgo(values, returns, rewards, lmb, gamma) - elif algorithm == 'VTRACE': + elif algorithm == "VTRACE": return vtrace(values, returns, rewards, lmb, gamma, rhos, cs) else: - print('No algorithm named %s' % algorithm) + print("No algorithm named %s" % algorithm) diff --git a/handyrl/model.py b/handyrl/model.py index 75c83ac1..0be4c011 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -5,10 +5,12 @@ import numpy as np import torch -torch.set_num_threads(1) + +# noqa idiom +if True: + torch.set_num_threads(1) import torch.nn as nn -import torch.nn.functional as F from .util import map_r @@ -27,13 +29,14 @@ def to_gpu(data): # model wrapper class + class ModelWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def init_hidden(self, batch_size=None): - if hasattr(self.model, 'init_hidden'): + if hasattr(self.model, "init_hidden"): return self.model.init_hidden(batch_size) return None @@ -42,7 +45,7 @@ def forward(self, *args, **kwargs): def inference(self, x, hidden, **kwargs): # numpy array -> numpy array - if hasattr(self.model, 'inference'): + if hasattr(self.model, "inference"): return self.model.inference(x, hidden, **kwargs) self.eval() @@ -55,10 +58,11 @@ def inference(self, x, hidden, **kwargs): # simple model + class RandomModel(nn.Module): def __init__(self, env): super().__init__() self.action_length = env.action_length() def inference(self, x=None, hidden=None): - return {'policy': np.zeros(self.action_length, dtype=np.float32), 'value': np.zeros(1, dtype=np.float32)} + return {"policy": np.zeros(self.action_length, dtype=np.float32), "value": np.zeros(1, dtype=np.float32)} diff --git a/handyrl/train.py b/handyrl/train.py index e37d64ee..fc38ead6 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -3,30 +3,29 @@ # training -import os -import time -import copy -import threading -import random import bz2 +import copy +import os import pickle +import random +import 
threading +import time import warnings from collections import deque import numpy as np +import psutil import torch +import torch.distributions as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributions as dist import torch.optim as optim -import psutil -from .environment import prepare_env, make_env -from .util import map_r, bimap_r, trimap_r, rotate -from .model import to_torch, to_gpu, RandomModel, ModelWrapper +from .connection import MultiProcessJobExecutor, accept_socket_connections +from .environment import make_env, prepare_env from .losses import compute_target -from .connection import MultiProcessJobExecutor -from .connection import accept_socket_connections +from .model import ModelWrapper, RandomModel, to_gpu, to_torch +from .util import bimap_r, map_r, rotate, trimap_r from .worker import WorkerCluster @@ -51,57 +50,65 @@ def replace_none(a, b): return a if a is not None else b for ep in episodes: - moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep['moment']], []) - moments = moments_[ep['start'] - ep['base']:ep['end'] - ep['base']] - players = list(moments[0]['observation'].keys()) - if not args['turn_based_training']: # solo training + moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep["moment"]], []) + moments = moments_[ep["start"] - ep["base"] : ep["end"] - ep["base"]] + players = list(moments[0]["observation"].keys()) + if not args["turn_based_training"]: # solo training players = [random.choice(players)] - obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o)) # template for padding - p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]]) # template for padding + obs_zeros = map_r(moments[0]["observation"][moments[0]["turn"][0]], lambda o: np.zeros_like(o)) # template for padding + p_zeros = np.zeros_like(moments[0]["policy"][moments[0]["turn"][0]]) # template for padding # data that is chainge by training configuration - if 
args['turn_based_training'] and not args['observation']: - obs = [[m['observation'][m['turn'][0]]] for m in moments] - p = np.array([[m['policy'][m['turn'][0]]] for m in moments]) - act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis] - amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments]) + if args["turn_based_training"] and not args["observation"]: + obs = [[m["observation"][m["turn"][0]]] for m in moments] + p = np.array([[m["policy"][m["turn"][0]]] for m in moments]) + act = np.array([[m["action"][m["turn"][0]]] for m in moments], dtype=np.int64)[..., np.newaxis] + amask = np.array([[m["action_mask"][m["turn"][0]]] for m in moments]) else: - obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments] - p = np.array([[replace_none(m['policy'][player], p_zeros) for player in players] for m in moments]) - act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis] - amask = np.array([[replace_none(m['action_mask'][player], p_zeros + 1e32) for player in players] for m in moments]) + obs = [[replace_none(m["observation"][player], obs_zeros) for player in players] for m in moments] + p = np.array([[replace_none(m["policy"][player], p_zeros) for player in players] for m in moments]) + act = np.array([[replace_none(m["action"][player], 0) for player in players] for m in moments], dtype=np.int64)[ + ..., np.newaxis + ] + amask = np.array([[replace_none(m["action_mask"][player], p_zeros + 1e32) for player in players] for m in moments]) # reshape observation obs = rotate(rotate(obs)) # (T, P, ..., ...) -> (P, ..., T, ...) -> (..., T, P, ...) 
obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o)) # datum that is not changed by training configuration - v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) - rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) - ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1) - oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) + v = np.array( + [[replace_none(m["value"][player], [0]) for player in players] for m in moments], dtype=np.float32 + ).reshape(len(moments), len(players), -1) + rew = np.array( + [[replace_none(m["reward"][player], [0]) for player in players] for m in moments], dtype=np.float32 + ).reshape(len(moments), len(players), -1) + ret = np.array( + [[replace_none(m["return"][player], [0]) for player in players] for m in moments], dtype=np.float32 + ).reshape(len(moments), len(players), -1) + oc = np.array([ep["outcome"][player] for player in players], dtype=np.float32).reshape(1, len(players), -1) emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask - tmask = np.array([[[m['policy'][player] is not None] for player in players] for m in moments], dtype=np.float32) - omask = np.array([[[m['value'][player] is not None] for player in players] for m in moments], dtype=np.float32) + tmask = np.array([[[m["policy"][player] is not None] for player in players] for m in moments], dtype=np.float32) + omask = np.array([[[m["value"][player] is not None] for player in players] for m in moments], dtype=np.float32) - progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total'] + progress = np.arange(ep["start"], ep["end"], dtype=np.float32)[..., np.newaxis] / 
ep["total"] # pad each array if step length is short - if len(tmask) < args['forward_steps']: - pad_len = args['forward_steps'] - len(tmask) - obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0)) - p = np.pad(p, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) + if len(tmask) < args["forward_steps"]: + pad_len = args["forward_steps"] - len(tmask) + obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), "constant", constant_values=0)) + p = np.pad(p, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) v = np.concatenate([v, np.tile(oc, [pad_len, 1, 1])]) - act = np.pad(act, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - rew = np.pad(rew, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - ret = np.pad(ret, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - emask = np.pad(emask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - tmask = np.pad(tmask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - omask = np.pad(omask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0) - amask = np.pad(amask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1e32) - progress = np.pad(progress, [(0, pad_len), (0, 0)], 'constant', constant_values=1) + act = np.pad(act, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + rew = np.pad(rew, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + ret = np.pad(ret, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + emask = np.pad(emask, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + tmask = np.pad(tmask, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + omask = np.pad(omask, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=0) + amask = np.pad(amask, [(0, pad_len), (0, 0), (0, 0)], "constant", constant_values=1e32) + progress = np.pad(progress, [(0, 
pad_len), (0, 0)], "constant", constant_values=1) obss.append(obs) datum.append((p, v, act, oc, rew, ret, emask, tmask, omask, amask, progress)) @@ -122,14 +129,18 @@ def replace_none(a, b): progress = to_torch(np.array(progress)) return { - 'observation': obs, - 'policy': p, 'value': v, - 'action': act, 'outcome': oc, - 'reward': rew, 'return': ret, - 'episode_mask': emask, - 'turn_mask': tmask, 'observation_mask': omask, - 'action_mask': amask, - 'progress': progress, + "observation": obs, + "policy": p, + "value": v, + "action": act, + "outcome": oc, + "reward": rew, + "return": ret, + "episode_mask": emask, + "turn_mask": tmask, + "observation_mask": omask, + "action_mask": amask, + "progress": progress, } @@ -145,7 +156,7 @@ def forward_prediction(model, hidden, batch, args): tuple: batch outputs of neural network """ - observations = batch['observation'] # (B, T, P, ...) + observations = batch["observation"] # (B, T, P, ...) if hidden is None: # feed-forward neural network @@ -154,33 +165,35 @@ def forward_prediction(model, hidden, batch, args): else: # sequential computation with RNN outputs = {} - for t in range(batch['turn_mask'].size(1)): + for t in range(batch["turn_mask"].size(1)): obs = map_r(observations, lambda o: o[:, t].reshape(-1, *o.size()[3:])) # (..., B * P, ...) - omask_ = batch['observation_mask'][:, t] + omask_ = batch["observation_mask"][:, t] omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (len(h.size()) - 2)))) hidden_ = bimap_r(hidden, omask, lambda h, m: h * m) # (..., B, P, ...) - if args['turn_based_training'] and not args['observation']: + if args["turn_based_training"] and not args["observation"]: hidden_ = map_r(hidden_, lambda h: h.sum(1)) # (..., B * 1, ...) else: hidden_ = map_r(hidden_, lambda h: h.view(-1, *h.size()[2:])) # (..., B * P, ...) 
outputs_ = model(obs, hidden_) for k, o in outputs_.items(): - if k == 'hidden': - next_hidden = outputs_['hidden'] + if k == "hidden": + next_hidden = outputs_["hidden"] else: outputs[k] = outputs.get(k, []) + [o] - next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:])) # (..., B, P or 1, ...) + next_hidden = bimap_r( + next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:]) + ) # (..., B, P or 1, ...) hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m) outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None} for k, o in outputs.items(): - o = o.view(*batch['turn_mask'].size()[:2], -1, o.size(-1)) - if k == 'policy': + o = o.view(*batch["turn_mask"].size()[:2], -1, o.size(-1)) + if k == "policy": # gather turn player's policies - outputs[k] = o.mul(batch['turn_mask']).sum(2, keepdim=True) - batch['action_mask'] + outputs[k] = o.mul(batch["turn_mask"]).sum(2, keepdim=True) - batch["action_mask"] else: # mask valid target values and cumulative rewards - outputs[k] = o.mul(batch['observation_mask']) + outputs[k] = o.mul(batch["observation_mask"]) return outputs @@ -192,37 +205,39 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba tuple: losses and statistic values and the number of training data """ - tmasks = batch['turn_mask'] - omasks = batch['observation_mask'] + tmasks = batch["turn_mask"] + omasks = batch["observation_mask"] losses = {} dcnt = tmasks.sum().item() turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True) - losses['p'] = (-log_selected_policies * turn_advantages).sum() - if 'value' in outputs: - losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2 - if 'return' in outputs: - losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum() + losses["p"] = (-log_selected_policies * turn_advantages).sum() + if "value" in 
outputs: + losses["v"] = ((outputs["value"] - targets["value"]) ** 2).mul(omasks).sum() / 2 + if "return" in outputs: + losses["r"] = F.smooth_l1_loss(outputs["return"], targets["return"], reduction="none").mul(omasks).sum() - entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1)) - losses['ent'] = entropy.sum() + entropy = dist.Categorical(logits=outputs["policy"]).entropy().mul(tmasks.sum(-1)) + losses["ent"] = entropy.sum() - base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0) - entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization'] - losses['total'] = base_loss + entropy_loss + base_loss = losses["p"] + losses.get("v", 0) + losses.get("r", 0) + entropy_loss = ( + entropy.mul(1 - batch["progress"] * (1 - args["entropy_regularization_decay"])).sum() * -args["entropy_regularization"] + ) + losses["total"] = base_loss + entropy_loss return losses, dcnt def compute_loss(batch, model, hidden, args): outputs = forward_prediction(model, hidden, batch, args) - actions = batch['action'] - emasks = batch['episode_mask'] + actions = batch["action"] + emasks = batch["episode_mask"] clip_rho_threshold, clip_c_threshold = 1.0, 1.0 - log_selected_b_policies = F.log_softmax(batch['policy'] , dim=-1).gather(-1, actions) * emasks - log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks + log_selected_b_policies = F.log_softmax(batch["policy"], dim=-1).gather(-1, actions) * emasks + log_selected_t_policies = F.log_softmax(outputs["policy"], dim=-1).gather(-1, actions) * emasks # thresholds of importance sampling log_rhos = log_selected_t_policies.detach() - log_selected_b_policies @@ -231,26 +246,36 @@ def compute_loss(batch, model, hidden, args): cs = torch.clamp(rhos, 0, clip_c_threshold) outputs_nograd = {k: o.detach() for k, o in outputs.items()} - if 'value' in outputs_nograd: - values_nograd = 
outputs_nograd['value'] - if args['turn_based_training'] and values_nograd.size(2) == 2: # two player zerosum game + if "value" in outputs_nograd: + values_nograd = outputs_nograd["value"] + if args["turn_based_training"] and values_nograd.size(2) == 2: # two player zerosum game values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2) - values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8) - outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks) + values_nograd = (values_nograd + values_nograd_opponent) / ( + batch["observation_mask"].sum(dim=2, keepdim=True) + 1e-8 + ) + outputs_nograd["value"] = values_nograd * emasks + batch["outcome"] * (1 - emasks) # compute targets and advantage targets = {} advantages = {} - value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs - return_args = outputs_nograd.get('return', None), batch['return'], batch['reward'], args['lambda'], args['gamma'], clipped_rhos, cs + value_args = outputs_nograd.get("value", None), batch["outcome"], None, args["lambda"], 1, clipped_rhos, cs + return_args = ( + outputs_nograd.get("return", None), + batch["return"], + batch["reward"], + args["lambda"], + args["gamma"], + clipped_rhos, + cs, + ) - targets['value'], advantages['value'] = compute_target(args['value_target'], *value_args) - targets['return'], advantages['return'] = compute_target(args['value_target'], *return_args) + targets["value"], advantages["value"] = compute_target(args["value_target"], *value_args) + targets["return"], advantages["return"] = compute_target(args["value_target"], *return_args) - if args['policy_target'] != args['value_target']: - _, advantages['value'] = compute_target(args['policy_target'], *value_args) - _, advantages['return'] = compute_target(args['policy_target'], *return_args) + if args["policy_target"] != args["value_target"]: 
+ _, advantages["value"] = compute_target(args["policy_target"], *value_args) + _, advantages["return"] = compute_target(args["policy_target"], *return_args) # compute policy advantage total_advantages = clipped_rhos * sum(advantages.values()) @@ -264,40 +289,43 @@ def __init__(self, args, episodes): self.episodes = episodes self.shutdown_flag = False - self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers'], num_receivers=2) + self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args["num_batchers"], num_receivers=2) def _selector(self): while True: - yield [self.select_episode() for _ in range(self.args['batch_size'])] + yield [self.select_episode() for _ in range(self.args["batch_size"])] def _worker(self, conn, bid): - print('started batcher %d' % bid) + print("started batcher %d" % bid) while not self.shutdown_flag: episodes = conn.recv() batch = make_batch(episodes, self.args) conn.send(batch) - print('finished batcher %d' % bid) + print("finished batcher %d" % bid) def run(self): self.executor.start() def select_episode(self): while True: - ep_idx = random.randrange(min(len(self.episodes), self.args['maximum_episodes'])) - accept_rate = 1 - (len(self.episodes) - 1 - ep_idx) / self.args['maximum_episodes'] + ep_idx = random.randrange(min(len(self.episodes), self.args["maximum_episodes"])) + accept_rate = 1 - (len(self.episodes) - 1 - ep_idx) / self.args["maximum_episodes"] if random.random() < accept_rate: break ep = self.episodes[ep_idx] - turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps']) # change start turn by sequence length + turn_candidates = 1 + max(0, ep["steps"] - self.args["forward_steps"]) # change start turn by sequence length st = random.randrange(turn_candidates) - ed = min(st + self.args['forward_steps'], ep['steps']) - st_block = st // self.args['compress_steps'] - ed_block = (ed - 1) // self.args['compress_steps'] + 1 + ed = min(st + 
self.args["forward_steps"], ep["steps"]) + st_block = st // self.args["compress_steps"] + ed_block = (ed - 1) // self.args["compress_steps"] + 1 ep_minimum = { - 'args': ep['args'], 'outcome': ep['outcome'], - 'moment': ep['moment'][st_block:ed_block], - 'base': st_block * self.args['compress_steps'], - 'start': st, 'end': ed, 'total': ep['steps'] + "args": ep["args"], + "outcome": ep["outcome"], + "moment": ep["moment"][st_block:ed_block], + "base": st_block * self.args["compress_steps"], + "start": st, + "end": ed, + "total": ep["steps"], } return ep_minimum @@ -316,7 +344,7 @@ def __init__(self, args, model): self.gpu = torch.cuda.device_count() self.model = model self.default_lr = 3e-8 - self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps'] + self.data_cnt_ema = self.args["batch_size"] * self.args["forward_steps"] self.params = list(self.model.parameters()) lr = self.default_lr * self.data_cnt_ema self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None @@ -328,7 +356,7 @@ def __init__(self, args, model): self.shutdown_flag = False def update(self): - if len(self.episodes) < self.args['minimum_episodes']: + if len(self.episodes) < self.args["minimum_episodes"]: return None, 0 # return None before training self.update_flag = True while True: @@ -370,8 +398,8 @@ def train(self): while data_cnt == 0 or not (self.update_flag or self.shutdown_flag): # episodes were only tuple of arrays batch = self.batcher.batch() - batch_size = batch['value'].size(0) - player_count = batch['value'].size(2) + batch_size = batch["value"].size(0) + player_count = batch["value"].size(2) hidden = model.init_hidden([batch_size, player_count]) if self.gpu > 0: batch = to_gpu(batch) @@ -380,7 +408,7 @@ def train(self): losses, dcnt = compute_loss(batch, train_model, hidden, self.args) self.optimizer.zero_grad() - losses['total'].backward() + losses["total"].backward() nn.utils.clip_grad_norm_(self.params, 4.0) 
self.optimizer.step() @@ -391,48 +419,48 @@ def train(self): self.steps += 1 - print('loss = %s' % ' '.join([k + ':' + '%.3f' % (l / data_cnt) for k, l in loss_sum.items()])) + print("loss = %s" % " ".join([k + ":" + "%.3f" % (l / data_cnt) for k, l in loss_sum.items()])) self.data_cnt_ema = self.data_cnt_ema * 0.8 + data_cnt / (1e-2 + batch_cnt) * 0.2 for param_group in self.optimizer.param_groups: - param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) + param_group["lr"] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5) self.model.cpu() self.model.eval() return copy.deepcopy(self.model) def run(self): - print('waiting training') + print("waiting training") while not self.shutdown_flag: - if len(self.episodes) < self.args['minimum_episodes']: + if len(self.episodes) < self.args["minimum_episodes"]: time.sleep(1) continue if self.steps == 0: self.batcher.run() - print('started training') + print("started training") model = self.train() self.report_update(model, self.steps) - print('finished training') + print("finished training") class Learner: def __init__(self, args, env=None, net=None, remote=False): - train_args = args['train_args'] - env_args = args['env_args'] - train_args['env'] = env_args + train_args = args["train_args"] + env_args = args["env_args"] + train_args["env"] = env_args args = train_args - args['remote'] = remote + args["remote"] = remote self.args = args - random.seed(args['seed']) + random.seed(args["seed"]) - self.env = env(args['env']) if env is not None else make_env(env_args) - eval_modify_rate = (args['update_episodes'] ** 0.85) / args['update_episodes'] - self.eval_rate = max(args['eval_rate'], eval_modify_rate) + self.env = env(args["env"]) if env is not None else make_env(env_args) + eval_modify_rate = (args["update_episodes"] ** 0.85) / args["update_episodes"] + self.eval_rate = max(args["eval_rate"], eval_modify_rate) self.shutdown_flag = False self.flags = set() # trained datum - 
self.model_era = self.args['restart_epoch'] + self.model_era = self.args["restart_epoch"] self.model_class = net if net is not None else self.env.net() train_model = self.model_class() if self.model_era == 0: @@ -463,17 +491,17 @@ def shutdown(self): thread.join() def model_path(self, model_id): - return os.path.join('models', str(model_id) + '.pth') + return os.path.join("models", str(model_id) + ".pth") def latest_model_path(self): - return os.path.join('models', 'latest.pth') + return os.path.join("models", "latest.pth") def update_model(self, model, steps): # get latest model and save it - print('updated model(%d)' % steps) + print("updated model(%d)" % steps) self.model_era += 1 self.model = model - os.makedirs('models', exist_ok=True) + os.makedirs("models", exist_ok=True) torch.save(model.state_dict(), self.model_path(self.model_era)) torch.save(model.state_dict(), self.latest_model_path()) @@ -482,9 +510,9 @@ def feed_episodes(self, episodes): for episode in episodes: if episode is None: continue - for p in episode['args']['player']: - model_id = episode['args']['model_id'][p] - outcome = episode['outcome'][p] + for p in episode["args"]["player"]: + model_id = episode["args"]["model_id"][p] + outcome = episode["outcome"][p] n, r, r2 = self.generation_results.get(model_id, (0, 0, 0)) self.generation_results[model_id] = n + 1, r + outcome, r2 + outcome ** 2 @@ -492,11 +520,11 @@ def feed_episodes(self, episodes): mem = psutil.virtual_memory() mem_used_ratio = mem.used / mem.total mem_ok = mem_used_ratio <= 0.95 - maximum_episodes = self.args['maximum_episodes'] if mem_ok else len(self.trainer.episodes) + maximum_episodes = self.args["maximum_episodes"] if mem_ok else len(self.trainer.episodes) - if not mem_ok and 'memory_over' not in self.flags: + if not mem_ok and "memory_over" not in self.flags: warnings.warn("memory usage %.1f%% with buffer size %d" % (mem_used_ratio * 100, len(self.trainer.episodes))) - self.flags.add('memory_over') + 
self.flags.add("memory_over") self.trainer.episodes.extend([e for e in episodes if e is not None]) while len(self.trainer.episodes) > maximum_episodes: @@ -507,31 +535,31 @@ def feed_results(self, results): for result in results: if result is None: continue - for p in result['args']['player']: - model_id = result['args']['model_id'][p] - res = result['result'][p] + for p in result["args"]["player"]: + model_id = result["args"]["model_id"][p] + res = result["result"][p] n, r, r2 = self.results.get(model_id, (0, 0, 0)) self.results[model_id] = n + 1, r + res, r2 + res ** 2 def update(self): # call update to every component print() - print('epoch %d' % self.model_era) + print("epoch %d" % self.model_era) if self.model_era not in self.results: - print('win rate = Nan (0)') + print("win rate = Nan (0)") else: n, r, r2 = self.results[self.model_era] mean = r / (n + 1e-6) - print('win rate = %.3f (%.1f / %d)' % ((mean + 1) / 2, (r + n) / 2, n)) + print("win rate = %.3f (%.1f / %d)" % ((mean + 1) / 2, (r + n) / 2, n)) if self.model_era not in self.generation_results: - print('generation stats = Nan (0)') + print("generation stats = Nan (0)") else: n, r, r2 = self.generation_results[self.model_era] mean = r / (n + 1e-6) std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5 - print('generation stats = %.3f +- %.3f' % (mean, std)) + print("generation stats = %.3f +- %.3f" % (mean, std)) model, steps = self.trainer.update() if model is None: @@ -544,11 +572,11 @@ def update(self): def server(self): # central conductor server # returns as list if getting multiple requests as list - print('started server') - prev_update_episodes = self.args['minimum_episodes'] - while self.model_era < self.args['epochs'] or self.args['epochs'] < 0: + print("started server") + prev_update_episodes = self.args["minimum_episodes"] + while self.model_era < self.args["epochs"] or self.args["epochs"] < 0: # no update call before storing minimum number of episodes + 1 age - next_update_episodes = 
prev_update_episodes + self.args['update_episodes'] + next_update_episodes = prev_update_episodes + self.args["update_episodes"] while not self.shutdown_flag and self.num_episodes < next_update_episodes: conn, (req, data) = self.worker.recv() multi_req = isinstance(data, list) @@ -556,58 +584,58 @@ def server(self): data = [data] send_data = [] - if req == 'args': + if req == "args": for _ in data: - args = {'model_id': {}} + args = {"model_id": {}} # decide role if self.num_results < self.eval_rate * self.num_episodes: - args['role'] = 'e' + args["role"] = "e" else: - args['role'] = 'g' + args["role"] = "g" - if args['role'] == 'g': + if args["role"] == "g": # genatation configuration - args['player'] = self.env.players() + args["player"] = self.env.players() for p in self.env.players(): - if p in args['player']: - args['model_id'][p] = self.model_era + if p in args["player"]: + args["model_id"][p] = self.model_era else: - args['model_id'][p] = -1 + args["model_id"][p] = -1 self.num_episodes += 1 if self.num_episodes % 100 == 0: - print(self.num_episodes, end=' ', flush=True) + print(self.num_episodes, end=" ", flush=True) - elif args['role'] == 'e': + elif args["role"] == "e": # evaluation configuration - args['player'] = [self.env.players()[self.num_results % len(self.env.players())]] + args["player"] = [self.env.players()[self.num_results % len(self.env.players())]] for p in self.env.players(): - if p in args['player']: - args['model_id'][p] = self.model_era + if p in args["player"]: + args["model_id"][p] = self.model_era else: - args['model_id'][p] = -1 + args["model_id"][p] = -1 self.num_results += 1 send_data.append(args) - elif req == 'episode': + elif req == "episode": # report generated episodes self.feed_episodes(data) send_data = [None] * len(data) - elif req == 'result': + elif req == "result": # report evaluation results self.feed_results(data) send_data = [None] * len(data) - elif req == 'model': + elif req == "model": for model_id in data: model = 
self.model if model_id != self.model_era: try: model = self.model_class() model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) - except: + except Exception: # return latest model if failed to load specified model pass send_data.append(pickle.dumps(model)) @@ -617,28 +645,28 @@ def server(self): self.worker.send(conn, send_data) prev_update_episodes = next_update_episodes self.update() - print('finished server') + print("finished server") def entry_server(self): port = 9999 - print('started entry server %d' % port) + print("started entry server %d" % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) while not self.shutdown_flag: conn = next(conn_acceptor) if conn is not None: worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) + print("accepted connection from %s!" % worker_args["address"]) args = copy.deepcopy(self.args) - args['worker'] = worker_args + args["worker"] = worker_args conn.send(args) conn.close() - print('finished entry server') + print("finished entry server") def run(self): try: # open threads self.threads = [threading.Thread(target=self.trainer.run)] - if self.args['remote']: + if self.args["remote"]: self.threads.append(threading.Thread(target=self.entry_server)) for thread in self.threads: thread.start() @@ -651,7 +679,7 @@ def run(self): def train_main(args): - prepare_env(args['env_args']) # preparing environment is needed in stand-alone mode + prepare_env(args["env_args"]) # preparing environment is needed in stand-alone mode learner = Learner(args=args) learner.run() diff --git a/handyrl/util.py b/handyrl/util.py index c3aaf564..d4d45d0d 100755 --- a/handyrl/util.py +++ b/handyrl/util.py @@ -34,26 +34,18 @@ def rotate(x, max_depth=1024): return x if isinstance(x, (list, tuple)): if isinstance(x[0], (list, tuple)): - return type(x[0])( - rotate(type(x)(xx[i] for xx in x), max_depth - 1) - for i, _ in enumerate(x[0]) - ) + return type(x[0])(rotate(type(x)(xx[i] 
for xx in x), max_depth - 1) for i, _ in enumerate(x[0])) elif isinstance(x[0], dict): - return type(x[0])( - (key, rotate(type(x)(xx[key] for xx in x), max_depth - 1)) - for key in x[0] - ) + return type(x[0])((key, rotate(type(x)(xx[key] for xx in x), max_depth - 1)) for key in x[0]) elif isinstance(x, dict): x_front = x[list(x.keys())[0]] if isinstance(x_front, (list, tuple)): return type(x_front)( - rotate(type(x)((key, xx[i]) for key, xx in x.items()), max_depth - 1) - for i, _ in enumerate(x_front) + rotate(type(x)((key, xx[i]) for key, xx in x.items()), max_depth - 1) for i, _ in enumerate(x_front) ) elif isinstance(x_front, dict): return type(x_front)( - (key2, rotate(type(x)((key1, xx[key2]) for key1, xx in x.items()), max_depth - 1)) - for key2 in x_front + (key2, rotate(type(x)((key1, xx[key2]) for key1, xx in x.items()), max_depth - 1)) for key2 in x_front ) return x diff --git a/handyrl/worker.py b/handyrl/worker.py index 7d4557c3..e84ce38b 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -3,19 +3,23 @@ # worker and gather +import functools +import multiprocessing as mp +import pickle import random import threading import time -import functools -from socket import gethostname from collections import deque -import multiprocessing as mp -import pickle +from socket import gethostname -from .environment import prepare_env, make_env -from .connection import QueueCommunicator -from .connection import send_recv, open_multiprocessing_connections -from .connection import connect_socket_connection, accept_socket_connections +from .connection import ( + QueueCommunicator, + accept_socket_connections, + connect_socket_connection, + open_multiprocessing_connections, + send_recv, +) +from .environment import make_env, prepare_env from .evaluation import Evaluator from .generation import Generator from .model import ModelWrapper @@ -23,20 +27,20 @@ class Worker: def __init__(self, args, conn, wid): - print('opened worker %d' % wid) + print("opened worker %d" 
% wid) self.worker_id = wid self.args = args self.conn = conn self.latest_model = -1, None - env = make_env({**args['env'], 'id': wid}) + env = make_env({**args["env"], "id": wid}) self.generator = Generator(env, self.args) self.evaluator = Evaluator(env, self.args) - random.seed(args['seed'] + wid) + random.seed(args["seed"] + wid) def __del__(self): - print('closed worker %d' % self.worker_id) + print("closed worker %d" % self.worker_id) def _gather_models(self, model_ids): model_pool = {} @@ -49,7 +53,7 @@ def _gather_models(self, model_ids): model_pool[model_id] = self.latest_model[1] else: # get model from server - model_pool[model_id] = ModelWrapper(pickle.loads(send_recv(self.conn, ('model', model_id)))) + model_pool[model_id] = ModelWrapper(pickle.loads(send_recv(self.conn, ("model", model_id)))) # update latest model if model_id > self.latest_model[0]: self.latest_model = model_id, model_pool[model_id] @@ -57,24 +61,24 @@ def _gather_models(self, model_ids): def run(self): while True: - args = send_recv(self.conn, ('args', None)) - role = args['role'] + args = send_recv(self.conn, ("args", None)) + role = args["role"] models = {} - if 'model_id' in args: - model_ids = list(args['model_id'].values()) + if "model_id" in args: + model_ids = list(args["model_id"].values()) model_pool = self._gather_models(model_ids) # make dict of models - for p, model_id in args['model_id'].items(): + for p, model_id in args["model_id"].items(): models[p] = model_pool[model_id] - if role == 'g': + if role == "g": episode = self.generator.execute(models, args) - send_recv(self.conn, ('episode', episode)) - elif role == 'e': + send_recv(self.conn, ("episode", episode)) + elif role == "e": result = self.evaluator.execute(models, args) - send_recv(self.conn, ('result', result)) + send_recv(self.conn, ("result", result)) def make_worker_args(args, n_ga, gaid, wid, conn): @@ -88,22 +92,20 @@ def open_worker(args, conn, wid): class Gather(QueueCommunicator): def __init__(self, args, 
conn, gaid): - print('started gather %d' % gaid) + print("started gather %d" % gaid) super().__init__() self.gather_id = gaid self.server_conn = conn self.args_queue = deque([]) - self.data_map = {'model': {}} + self.data_map = {"model": {}} self.result_send_map = {} self.result_send_cnt = 0 - n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] + n_pro, n_ga = args["worker"]["num_parallel"], args["worker"]["num_gathers"] num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga) worker_conns = open_multiprocessing_connections( - num_workers_per_gather, - open_worker, - functools.partial(make_worker_args, args, n_ga, gaid) + num_workers_per_gather, open_worker, functools.partial(make_worker_args, args, n_ga, gaid) ) for conn in worker_conns: @@ -113,12 +115,12 @@ def __init__(self, args, conn, gaid): self.result_buf_len = 1 + len(worker_conns) // 4 def __del__(self): - print('finished gather %d' % self.gather_id) + print("finished gather %d" % self.gather_id) def run(self): while True: conn, (command, args) = self.recv() - if command == 'args': + if command == "args": # When requested arguments, return buffered outputs if len(self.args_queue) == 0: # get multiple arguments from server and store them @@ -167,24 +169,25 @@ def __init__(self, args): self.args = args def run(self): - if self.args['remote']: + if self.args["remote"]: # prepare listening connections def worker_server(port): conn_acceptor = accept_socket_connections(port=port, timeout=0.5) - print('started worker server %d' % port) + print("started worker server %d" % port) while not self.shutdown_flag: # use super class's flag conn = next(conn_acceptor) if conn is not None: self.add_connection(conn) - print('finished worker server') + print("finished worker server") + # use super class's thread list self.threads.append(threading.Thread(target=worker_server, args=(9998,))) self.threads[-1].start() else: # open local connections - if 'num_gathers' not in self.args['worker']: 
- self.args['worker']['num_gathers'] = 1 + max(0, self.args['worker']['num_parallel'] - 1) // 16 - for i in range(self.args['worker']['num_gathers']): + if "num_gathers" not in self.args["worker"]: + self.args["worker"]["num_gathers"] = 1 + max(0, self.args["worker"]["num_parallel"] - 1) // 16 + for i in range(self.args["worker"]["num_gathers"]): conn0, conn1 = mp.Pipe(duplex=True) mp.Process(target=gather_loop, args=(self.args, conn1, i)).start() conn1.close() @@ -192,7 +195,7 @@ def worker_server(port): def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) + conn = connect_socket_connection(worker_args["server_address"], 9999) conn.send(worker_args) args = conn.recv() conn.close() @@ -201,20 +204,20 @@ def entry(worker_args): def worker_main(args): # offline generation worker - worker_args = args['worker_args'] - worker_args['address'] = gethostname() - if 'num_gathers' not in worker_args: - worker_args['num_gathers'] = 1 + max(0, worker_args['num_parallel'] - 1) // 16 + worker_args = args["worker_args"] + worker_args["address"] = gethostname() + if "num_gathers" not in worker_args: + worker_args["num_gathers"] = 1 + max(0, worker_args["num_parallel"] - 1) // 16 args = entry(worker_args) print(args) - prepare_env(args['env']) + prepare_env(args["env"]) # open workers process = [] try: - for i in range(args['worker']['num_gathers']): - conn = connect_socket_connection(args['worker']['server_address'], 9998) + for i in range(args["worker"]["num_gathers"]): + conn = connect_socket_connection(args["worker"]["server_address"], 9998) p = mp.Process(target=gather_loop, args=(args, conn, i)) p.start() conn.close() diff --git a/main.py b/main.py index 73a6d13b..3cbf7f28 100755 --- a/main.py +++ b/main.py @@ -3,39 +3,45 @@ import os import sys -import yaml +import yaml -if __name__ == '__main__': - os.environ['OMP_NUM_THREADS'] = '1' +if __name__ == "__main__": + os.environ["OMP_NUM_THREADS"] = "1" - with open('config.yaml') as f: 
+ with open("config.yaml") as f: args = yaml.safe_load(f) print(args) if len(sys.argv) < 2: - print('Please set mode of HandyRL.') + print("Please set mode of HandyRL.") exit(1) mode = sys.argv[1] - if mode == '--train' or mode == '-t': + if mode == "--train" or mode == "-t": from handyrl.train import train_main as main + main(args) - elif mode == '--train-server' or mode == '-ts': + elif mode == "--train-server" or mode == "-ts": from handyrl.train import train_server_main as main + main(args) - elif mode == '--worker' or mode == '-w': + elif mode == "--worker" or mode == "-w": from handyrl.worker import worker_main as main + main(args) - elif mode == '--eval' or mode == '-e': + elif mode == "--eval" or mode == "-e": from handyrl.evaluation import eval_main as main + main(args, sys.argv[2:]) - elif mode == '--eval-server' or mode == '-es': + elif mode == "--eval-server" or mode == "-es": from handyrl.evaluation import eval_server_main as main + main(args, sys.argv[2:]) - elif mode == '--eval-client' or mode == '-ec': + elif mode == "--eval-client" or mode == "-ec": from handyrl.evaluation import eval_client_main as main + main(args, sys.argv[2:]) else: - print('Not found mode %s.' % mode) + print("Not found mode %s." % mode) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 00000000..0cf45282 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,891 @@ +[[package]] +name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "atomicwrites" +version = "1.4.0" +description = "Atomic file writes." 
+category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "attrs" +version = "20.3.0" +description = "Classes Without Boilerplate" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] +docs = ["furo", "sphinx", "zope.interface"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] + +[[package]] +name = "black" +version = "20.8b1" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +appdirs = "*" +click = ">=7.1.2" +mypy-extensions = ">=0.4.3" +pathspec = ">=0.6,<1" +regex = ">=2020.1.8" +toml = ">=0.10.1" +typed-ast = ">=1.4.0" +typing-extensions = ">=3.7.4" + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] + +[[package]] +name = "certifi" +version = "2020.12.5" +description = "Python package for providing Mozilla's CA Bundle." +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "chardet" +version = "4.0.0" +description = "Universal encoding detector for Python 2 and 3" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "click" +version = "7.1.2" +description = "Composable command line interface toolkit" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "colorama" +version = "0.4.4" +description = "Cross-platform colored terminal text." 
+category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "colorlog" +version = "4.8.0" +description = "Log formatting with colors!" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} + +[[package]] +name = "dacite" +version = "1.6.0" +description = "Simple creation of data classes from dictionaries." +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +dev = ["pytest (>=5)", "pytest-cov", "coveralls", "black", "mypy", "pylint"] + +[[package]] +name = "flake8" +version = "3.9.1" +description = "the modular source code checker: pep8 pyflakes and co" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[package.dependencies] +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +mccabe = ">=0.6.0,<0.7.0" +pycodestyle = ">=2.7.0,<2.8.0" +pyflakes = ">=2.3.0,<2.4.0" + +[[package]] +name = "flake8-bugbear" +version = "21.4.3" +description = "A plugin for flake8 finding likely bugs and design problems in your program. Contains warnings that don't belong in pyflakes and pycodestyle." 
+category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +attrs = ">=19.2.0" +flake8 = ">=3.0.0" + +[package.extras] +dev = ["coverage", "black", "hypothesis", "hypothesmith"] + +[[package]] +name = "gitdb" +version = "4.0.7" +description = "Git Object Database" +category = "dev" +optional = false +python-versions = ">=3.4" + +[package.dependencies] +smmap = ">=3.0.1,<5" + +[[package]] +name = "gitpython" +version = "3.1.14" +description = "Python Git Library" +category = "dev" +optional = false +python-versions = ">=3.4" + +[package.dependencies] +gitdb = ">=4.0.1,<5" + +[[package]] +name = "idna" +version = "2.10" +description = "Internationalized Domain Names in Applications (IDNA)" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "importlib-metadata" +version = "3.10.0" +description = "Read metadata from Python packages" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} +zipp = ">=0.5" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] + +[[package]] +name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "isort" +version = "5.1.4" +description = "A Python utility / library to sort Python imports." 
+category = "dev" +optional = false +python-versions = ">=3.6,<4.0" + +[package.extras] +pipfile_deprecated_finder = ["pipreqs", "requirementslib", "tomlkit (>=0.5.3)"] +requirements_deprecated_finder = ["pipreqs", "pip-api"] + +[[package]] +name = "jsonschema" +version = "3.2.0" +description = "An implementation of JSON Schema validation for Python" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +attrs = ">=17.4.0" +importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +pyrsistent = ">=0.14.0" +six = ">=1.11.0" + +[package.extras] +format = ["idna", "jsonpointer (>1.13)", "rfc3987", "strict-rfc3339", "webcolors"] +format_nongpl = ["idna", "jsonpointer (>1.13)", "webcolors", "rfc3986-validator (>0.1.0)", "rfc3339-validator"] + +[[package]] +name = "kaggle-environments" +version = "1.7.11" +description = "Kaggle Environments" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +jsonschema = ">=3.0.1" + +[[package]] +name = "mccabe" +version = "0.6.1" +description = "McCabe checker, plugin for flake8" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "mypy" +version = "0.790" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +mypy-extensions = ">=0.4.3,<0.5.0" +typed-ast = ">=1.4.0,<1.5.0" +typing-extensions = ">=3.7.4" + +[package.extras] +dmypy = ["psutil (>=4.0)"] + +[[package]] +name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "numpy" +version = "1.20.2" +description = "NumPy is the fundamental package for array computing with Python." 
+category = "main" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "packaging" +version = "20.9" +description = "Core utilities for Python packages" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +pyparsing = ">=2.0.2" + +[[package]] +name = "pathspec" +version = "0.8.1" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "pluggy" +version = "0.13.1" +description = "plugin and hook calling mechanisms for python" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.dependencies] +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} + +[package.extras] +dev = ["pre-commit", "tox"] + +[[package]] +name = "psutil" +version = "5.8.0" +description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"] + +[[package]] +name = "py" +version = "1.10.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "pycodestyle" +version = "2.7.0" +description = "Python style guide checker" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "pyflakes" +version = "2.3.1" +description = "passive checker of Python programs" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[[package]] +name = "pyparsing" +version = "2.4.7" +description = "Python parsing module" +category = "dev" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "pyrsistent" +version = "0.17.3" +description = "Persistent/Functional/Immutable data structures" +category = "dev" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "pysen" +version = "0.9.1" +description = "Python linting made easy. Also a casual yet honorific way to address individuals who have entered an organization prior to you." 
+category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +black = {version = ">=19.10b0,<=20.8", optional = true, markers = "extra == \"lint\""} +colorlog = ">=4.0.0,<5.0.0" +dacite = ">=1.1.0,<2.0.0" +flake8 = {version = ">=3.7,<4", optional = true, markers = "extra == \"lint\""} +flake8-bugbear = {version = "*", optional = true, markers = "extra == \"lint\""} +GitPython = ">=3.0.0,<4.0.0" +isort = {version = ">=4.3,<5.2.0", optional = true, markers = "extra == \"lint\""} +mypy = {version = ">=0.770,<0.800", optional = true, markers = "extra == \"lint\""} +tomlkit = ">=0.5.11,<1.0.0" +unidiff = ">=0.6.0,<1.0.0" + +[package.extras] +lint = ["black (>=19.10b0,<=20.8)", "flake8-bugbear", "flake8 (>=3.7,<4)", "isort (>=4.3,<5.2.0)", "mypy (>=0.770,<0.800)"] + +[[package]] +name = "pytest" +version = "6.2.2" +description = "pytest: simple powerful testing with Python" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} +attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<1.0.0a1" +py = ">=1.8.2" +toml = "*" + +[package.extras] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] + +[[package]] +name = "pyyaml" +version = "5.4.1" +description = "YAML parser and emitter for Python" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" + +[[package]] +name = "regex" +version = "2021.4.4" +description = "Alternative regular expression module, to replace re." +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "requests" +version = "2.25.1" +description = "Python HTTP for Humans." 
+category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[package.dependencies] +certifi = ">=2017.4.17" +chardet = ">=3.0.2,<5" +idna = ">=2.5,<3" +urllib3 = ">=1.21.1,<1.27" + +[package.extras] +security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] +socks = ["PySocks (>=1.5.6,!=1.5.7)", "win-inet-pton"] + +[[package]] +name = "six" +version = "1.15.0" +description = "Python 2 and 3 compatibility utilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "smmap" +version = "4.0.0" +description = "A pure Python implementation of a sliding window memory map manager" +category = "dev" +optional = false +python-versions = ">=3.5" + +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" + +[[package]] +name = "tomlkit" +version = "0.7.0" +description = "Style preserving TOML library" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" + +[[package]] +name = "torch" +version = "1.8.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +numpy = "*" +typing-extensions = "*" + +[[package]] +name = "typed-ast" +version = "1.4.3" +description = "a fork of Python 2 and 3 ast modules with type comment support" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "typing-extensions" +version = "3.7.4.3" +description = "Backported and Experimental Type Hints for Python 3.5+" +category = "main" +optional = false +python-versions = "*" + +[[package]] +name = "unidiff" +version = "0.6.0" +description = "Unified diff parsing/metadata extraction library." 
+category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "urllib3" +version = "1.26.4" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" + +[package.extras] +secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] +socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +brotli = ["brotlipy (>=0.6.0)"] + +[[package]] +name = "zipp" +version = "3.4.1" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.extras] +docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] +testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] + +[metadata] +lock-version = "1.1" +python-versions = "^3.7" +content-hash = "1ecb3d64dd65bd7e44866dbba0c698a83e8d6ae56d1bf7031f26d187f27ca234" + +[metadata.files] +appdirs = [ + {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, + {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, +] +atomicwrites = [ + {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, + {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, +] +attrs = [ + {file = "attrs-20.3.0-py2.py3-none-any.whl", hash = "sha256:31b2eced602aa8423c2aea9c76a724617ed67cf9513173fd3a4f03e3a929c7e6"}, + {file = "attrs-20.3.0.tar.gz", hash = "sha256:832aa3cde19744e49938b91fea06d69ecb9e649c93ba974535d08ad92164f700"}, +] +black = [ + {file = "black-20.8b1.tar.gz", hash = 
"sha256:1c02557aa099101b9d21496f8a914e9ed2222ef70336404eeeac8edba836fbea"}, +] +certifi = [ + {file = "certifi-2020.12.5-py2.py3-none-any.whl", hash = "sha256:719a74fb9e33b9bd44cc7f3a8d94bc35e4049deebe19ba7d8e108280cfd59830"}, + {file = "certifi-2020.12.5.tar.gz", hash = "sha256:1a4995114262bffbc2413b159f2a1a480c969de6e6eb13ee966d470af86af59c"}, +] +chardet = [ + {file = "chardet-4.0.0-py2.py3-none-any.whl", hash = "sha256:f864054d66fd9118f2e67044ac8981a54775ec5b67aed0441892edb553d21da5"}, + {file = "chardet-4.0.0.tar.gz", hash = "sha256:0d6f53a15db4120f2b08c94f11e7d93d2c911ee118b6b30a04ec3ee8310179fa"}, +] +click = [ + {file = "click-7.1.2-py2.py3-none-any.whl", hash = "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc"}, + {file = "click-7.1.2.tar.gz", hash = "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a"}, +] +colorama = [ + {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, + {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, +] +colorlog = [ + {file = "colorlog-4.8.0-py2.py3-none-any.whl", hash = "sha256:3dd15cb27e8119a24c1a7b5c93f9f3b455855e0f73993b1c25921b2f646f1dcd"}, + {file = "colorlog-4.8.0.tar.gz", hash = "sha256:59b53160c60902c405cdec28d38356e09d40686659048893e026ecbd589516b1"}, +] +dacite = [ + {file = "dacite-1.6.0-py3-none-any.whl", hash = "sha256:4331535f7aabb505c732fa4c3c094313fc0a1d5ea19907bf4726a7819a68b93f"}, + {file = "dacite-1.6.0.tar.gz", hash = "sha256:d48125ed0a0352d3de9f493bf980038088f45f3f9d7498f090b50a847daaa6df"}, +] +flake8 = [ + {file = "flake8-3.9.1-py2.py3-none-any.whl", hash = "sha256:3b9f848952dddccf635be78098ca75010f073bfe14d2c6bda867154bea728d2a"}, + {file = "flake8-3.9.1.tar.gz", hash = "sha256:1aa8990be1e689d96c745c5682b687ea49f2e05a443aff1f8251092b0014e378"}, +] +flake8-bugbear = [ + {file = "flake8-bugbear-21.4.3.tar.gz", hash = 
"sha256:2346c81f889955b39e4a368eb7d508de723d9de05716c287dc860a4073dc57e7"}, + {file = "flake8_bugbear-21.4.3-py36.py37.py38-none-any.whl", hash = "sha256:4f305dca96be62bf732a218fe6f1825472a621d3452c5b994d8f89dae21dbafa"}, +] +gitdb = [ + {file = "gitdb-4.0.7-py3-none-any.whl", hash = "sha256:6c4cc71933456991da20917998acbe6cf4fb41eeaab7d6d67fbc05ecd4c865b0"}, + {file = "gitdb-4.0.7.tar.gz", hash = "sha256:96bf5c08b157a666fec41129e6d327235284cca4c81e92109260f353ba138005"}, +] +gitpython = [ + {file = "GitPython-3.1.14-py3-none-any.whl", hash = "sha256:3283ae2fba31c913d857e12e5ba5f9a7772bbc064ae2bb09efafa71b0dd4939b"}, + {file = "GitPython-3.1.14.tar.gz", hash = "sha256:be27633e7509e58391f10207cd32b2a6cf5b908f92d9cd30da2e514e1137af61"}, +] +idna = [ + {file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"}, + {file = "idna-2.10.tar.gz", hash = "sha256:b307872f855b18632ce0c21c5e45be78c0ea7ae4c15c828c20788b26921eb3f6"}, +] +importlib-metadata = [ + {file = "importlib_metadata-3.10.0-py3-none-any.whl", hash = "sha256:d2d46ef77ffc85cbf7dac7e81dd663fde71c45326131bea8033b9bad42268ebe"}, + {file = "importlib_metadata-3.10.0.tar.gz", hash = "sha256:c9db46394197244adf2f0b08ec5bc3cf16757e9590b02af1fca085c16c0d600a"}, +] +iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, + {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, +] +isort = [ + {file = "isort-5.1.4-py3-none-any.whl", hash = "sha256:ae3007f72a2e9da36febd3454d8be4b175d6ca17eb765841d5fe3d038aede79d"}, + {file = "isort-5.1.4.tar.gz", hash = "sha256:145072eedc4927cc9c1f9478f2d83b2fc1e6469df4129c02ef4e8c742207a46c"}, +] +jsonschema = [ + {file = "jsonschema-3.2.0-py2.py3-none-any.whl", hash = "sha256:4e5b3cf8216f577bee9ce139cbe72eca3ea4f292ec60928ff24758ce626cd163"}, + {file = 
"jsonschema-3.2.0.tar.gz", hash = "sha256:c8a85b28d377cc7737e46e2d9f2b4f44ee3c0e1deac6bf46ddefc7187d30797a"}, +] +kaggle-environments = [ + {file = "kaggle-environments-1.7.11.tar.gz", hash = "sha256:cd82d55ba298b74b28bfa7c706a3b52edefa02db92ce76f75b76f5d1cb1bc181"}, + {file = "kaggle_environments-1.7.11-py2.py3-none-any.whl", hash = "sha256:16a79e4d36c31bab73df746164f5dfdee4c20da514811bad1b8cd4aa2ae158d0"}, +] +mccabe = [ + {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, + {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, +] +mypy = [ + {file = "mypy-0.790-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:bd03b3cf666bff8d710d633d1c56ab7facbdc204d567715cb3b9f85c6e94f669"}, + {file = "mypy-0.790-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:2170492030f6faa537647d29945786d297e4862765f0b4ac5930ff62e300d802"}, + {file = "mypy-0.790-cp35-cp35m-win_amd64.whl", hash = "sha256:e86bdace26c5fe9cf8cb735e7cedfe7850ad92b327ac5d797c656717d2ca66de"}, + {file = "mypy-0.790-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e97e9c13d67fbe524be17e4d8025d51a7dca38f90de2e462243ab8ed8a9178d1"}, + {file = "mypy-0.790-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:0d34d6b122597d48a36d6c59e35341f410d4abfa771d96d04ae2c468dd201abc"}, + {file = "mypy-0.790-cp36-cp36m-win_amd64.whl", hash = "sha256:72060bf64f290fb629bd4a67c707a66fd88ca26e413a91384b18db3876e57ed7"}, + {file = "mypy-0.790-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:eea260feb1830a627fb526d22fbb426b750d9f5a47b624e8d5e7e004359b219c"}, + {file = "mypy-0.790-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:c614194e01c85bb2e551c421397e49afb2872c88b5830e3554f0519f9fb1c178"}, + {file = "mypy-0.790-cp37-cp37m-win_amd64.whl", hash = "sha256:0a0d102247c16ce93c97066443d11e2d36e6cc2a32d8ccc1f705268970479324"}, + {file = "mypy-0.790-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:cf4e7bf7f1214826cf7333627cb2547c0db7e3078723227820d0a2490f117a01"}, + {file = "mypy-0.790-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:af4e9ff1834e565f1baa74ccf7ae2564ae38c8df2a85b057af1dbbc958eb6666"}, + {file = "mypy-0.790-cp38-cp38-win_amd64.whl", hash = "sha256:da56dedcd7cd502ccd3c5dddc656cb36113dd793ad466e894574125945653cea"}, + {file = "mypy-0.790-py3-none-any.whl", hash = "sha256:2842d4fbd1b12ab422346376aad03ff5d0805b706102e475e962370f874a5122"}, + {file = "mypy-0.790.tar.gz", hash = "sha256:2b21ba45ad9ef2e2eb88ce4aeadd0112d0f5026418324176fd494a6824b74975"}, +] +mypy-extensions = [ + {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, + {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, +] +numpy = [ + {file = "numpy-1.20.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e9459f40244bb02b2f14f6af0cd0732791d72232bbb0dc4bab57ef88e75f6935"}, + {file = "numpy-1.20.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:a8e6859913ec8eeef3dbe9aed3bf475347642d1cdd6217c30f28dee8903528e6"}, + {file = "numpy-1.20.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:9cab23439eb1ebfed1aaec9cd42b7dc50fc96d5cd3147da348d9161f0501ada5"}, + {file = "numpy-1.20.2-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:9c0fab855ae790ca74b27e55240fe4f2a36a364a3f1ebcfd1fb5ac4088f1cec3"}, + {file = "numpy-1.20.2-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:61d5b4cf73622e4d0c6b83408a16631b670fc045afd6540679aa35591a17fe6d"}, + {file = "numpy-1.20.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:d15007f857d6995db15195217afdbddfcd203dfaa0ba6878a2f580eaf810ecd6"}, + {file = "numpy-1.20.2-cp37-cp37m-win32.whl", hash = "sha256:d76061ae5cab49b83a8cf3feacefc2053fac672728802ac137dd8c4123397677"}, + {file = "numpy-1.20.2-cp37-cp37m-win_amd64.whl", hash = 
"sha256:bad70051de2c50b1a6259a6df1daaafe8c480ca98132da98976d8591c412e737"}, + {file = "numpy-1.20.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:719656636c48be22c23641859ff2419b27b6bdf844b36a2447cb39caceb00935"}, + {file = "numpy-1.20.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:aa046527c04688af680217fffac61eec2350ef3f3d7320c07fd33f5c6e7b4d5f"}, + {file = "numpy-1.20.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:2428b109306075d89d21135bdd6b785f132a1f5a3260c371cee1fae427e12727"}, + {file = "numpy-1.20.2-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:e8e4fbbb7e7634f263c5b0150a629342cc19b47c5eba8d1cd4363ab3455ab576"}, + {file = "numpy-1.20.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:edb1f041a9146dcf02cd7df7187db46ab524b9af2515f392f337c7cbbf5b52cd"}, + {file = "numpy-1.20.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:c73a7975d77f15f7f68dacfb2bca3d3f479f158313642e8ea9058eea06637931"}, + {file = "numpy-1.20.2-cp38-cp38-win32.whl", hash = "sha256:6c915ee7dba1071554e70a3664a839fbc033e1d6528199d4621eeaaa5487ccd2"}, + {file = "numpy-1.20.2-cp38-cp38-win_amd64.whl", hash = "sha256:471c0571d0895c68da309dacee4e95a0811d0a9f9f532a48dc1bea5f3b7ad2b7"}, + {file = "numpy-1.20.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4703b9e937df83f5b6b7447ca5912b5f5f297aba45f91dbbbc63ff9278c7aa98"}, + {file = "numpy-1.20.2-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:abc81829c4039e7e4c30f7897938fa5d4916a09c2c7eb9b244b7a35ddc9656f4"}, + {file = "numpy-1.20.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:377751954da04d4a6950191b20539066b4e19e3b559d4695399c5e8e3e683bf6"}, + {file = "numpy-1.20.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6e51e417d9ae2e7848314994e6fc3832c9d426abce9328cf7571eefceb43e6c9"}, + {file = "numpy-1.20.2-cp39-cp39-win32.whl", hash = "sha256:780ae5284cb770ade51d4b4a7dce4faa554eb1d88a56d0e8b9f35fca9b0270ff"}, + {file = "numpy-1.20.2-cp39-cp39-win_amd64.whl", hash = 
"sha256:924dc3f83de20437de95a73516f36e09918e9c9c18d5eac520062c49191025fb"}, + {file = "numpy-1.20.2-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:97ce8b8ace7d3b9288d88177e66ee75480fb79b9cf745e91ecfe65d91a856042"}, + {file = "numpy-1.20.2.zip", hash = "sha256:878922bf5ad7550aa044aa9301d417e2d3ae50f0f577de92051d739ac6096cee"}, +] +packaging = [ + {file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"}, + {file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"}, +] +pathspec = [ + {file = "pathspec-0.8.1-py2.py3-none-any.whl", hash = "sha256:aa0cb481c4041bf52ffa7b0d8fa6cd3e88a2ca4879c533c9153882ee2556790d"}, + {file = "pathspec-0.8.1.tar.gz", hash = "sha256:86379d6b86d75816baba717e64b1a3a3469deb93bb76d613c9ce79edc5cb68fd"}, +] +pluggy = [ + {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, + {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, +] +psutil = [ + {file = "psutil-5.8.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:0066a82f7b1b37d334e68697faba68e5ad5e858279fd6351c8ca6024e8d6ba64"}, + {file = "psutil-5.8.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:0ae6f386d8d297177fd288be6e8d1afc05966878704dad9847719650e44fc49c"}, + {file = "psutil-5.8.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:12d844996d6c2b1d3881cfa6fa201fd635971869a9da945cf6756105af73d2df"}, + {file = "psutil-5.8.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:02b8292609b1f7fcb34173b25e48d0da8667bc85f81d7476584d889c6e0f2131"}, + {file = "psutil-5.8.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:6ffe81843131ee0ffa02c317186ed1e759a145267d54fdef1bc4ea5f5931ab60"}, + {file = "psutil-5.8.0-cp27-none-win32.whl", hash = 
"sha256:ea313bb02e5e25224e518e4352af4bf5e062755160f77e4b1767dd5ccb65f876"}, + {file = "psutil-5.8.0-cp27-none-win_amd64.whl", hash = "sha256:5da29e394bdedd9144c7331192e20c1f79283fb03b06e6abd3a8ae45ffecee65"}, + {file = "psutil-5.8.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:74fb2557d1430fff18ff0d72613c5ca30c45cdbfcddd6a5773e9fc1fe9364be8"}, + {file = "psutil-5.8.0-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:74f2d0be88db96ada78756cb3a3e1b107ce8ab79f65aa885f76d7664e56928f6"}, + {file = "psutil-5.8.0-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:99de3e8739258b3c3e8669cb9757c9a861b2a25ad0955f8e53ac662d66de61ac"}, + {file = "psutil-5.8.0-cp36-cp36m-win32.whl", hash = "sha256:36b3b6c9e2a34b7d7fbae330a85bf72c30b1c827a4366a07443fc4b6270449e2"}, + {file = "psutil-5.8.0-cp36-cp36m-win_amd64.whl", hash = "sha256:52de075468cd394ac98c66f9ca33b2f54ae1d9bff1ef6b67a212ee8f639ec06d"}, + {file = "psutil-5.8.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c6a5fd10ce6b6344e616cf01cc5b849fa8103fbb5ba507b6b2dee4c11e84c935"}, + {file = "psutil-5.8.0-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:61f05864b42fedc0771d6d8e49c35f07efd209ade09a5afe6a5059e7bb7bf83d"}, + {file = "psutil-5.8.0-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:0dd4465a039d343925cdc29023bb6960ccf4e74a65ad53e768403746a9207023"}, + {file = "psutil-5.8.0-cp37-cp37m-win32.whl", hash = "sha256:1bff0d07e76114ec24ee32e7f7f8d0c4b0514b3fae93e3d2aaafd65d22502394"}, + {file = "psutil-5.8.0-cp37-cp37m-win_amd64.whl", hash = "sha256:fcc01e900c1d7bee2a37e5d6e4f9194760a93597c97fee89c4ae51701de03563"}, + {file = "psutil-5.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6223d07a1ae93f86451d0198a0c361032c4c93ebd4bf6d25e2fb3edfad9571ef"}, + {file = "psutil-5.8.0-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:d225cd8319aa1d3c85bf195c4e07d17d3cd68636b8fc97e6cf198f782f99af28"}, + {file = "psutil-5.8.0-cp38-cp38-manylinux2010_x86_64.whl", hash = 
"sha256:28ff7c95293ae74bf1ca1a79e8805fcde005c18a122ca983abf676ea3466362b"}, + {file = "psutil-5.8.0-cp38-cp38-win32.whl", hash = "sha256:ce8b867423291cb65cfc6d9c4955ee9bfc1e21fe03bb50e177f2b957f1c2469d"}, + {file = "psutil-5.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:90f31c34d25b1b3ed6c40cdd34ff122b1887a825297c017e4cbd6796dd8b672d"}, + {file = "psutil-5.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6323d5d845c2785efb20aded4726636546b26d3b577aded22492908f7c1bdda7"}, + {file = "psutil-5.8.0-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:245b5509968ac0bd179287d91210cd3f37add77dad385ef238b275bad35fa1c4"}, + {file = "psutil-5.8.0-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:90d4091c2d30ddd0a03e0b97e6a33a48628469b99585e2ad6bf21f17423b112b"}, + {file = "psutil-5.8.0-cp39-cp39-win32.whl", hash = "sha256:ea372bcc129394485824ae3e3ddabe67dc0b118d262c568b4d2602a7070afdb0"}, + {file = "psutil-5.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:f4634b033faf0d968bb9220dd1c793b897ab7f1189956e1aa9eae752527127d3"}, + {file = "psutil-5.8.0.tar.gz", hash = "sha256:0c9ccb99ab76025f2f0bbecf341d4656e9c1351db8cc8a03ccd62e318ab4b5c6"}, +] +py = [ + {file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, + {file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, +] +pycodestyle = [ + {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"}, + {file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"}, +] +pyflakes = [ + {file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"}, + {file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"}, +] +pyparsing = [ + {file = 
"pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, + {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, +] +pyrsistent = [ + {file = "pyrsistent-0.17.3.tar.gz", hash = "sha256:2e636185d9eb976a18a8a8e96efce62f2905fea90041958d8cc2a189756ebf3e"}, +] +pysen = [ + {file = "pysen-0.9.1-py3-none-any.whl", hash = "sha256:706088b904a74b83a341cc7b19f213737412575bc74e851a57e7b6db80e437c9"}, + {file = "pysen-0.9.1.tar.gz", hash = "sha256:c84953b8eaec7a968e42a89f474ba177665abdf4e051352ec6931a3e96977a41"}, +] +pytest = [ + {file = "pytest-6.2.2-py3-none-any.whl", hash = "sha256:b574b57423e818210672e07ca1fa90aaf194a4f63f3ab909a2c67ebb22913839"}, + {file = "pytest-6.2.2.tar.gz", hash = "sha256:9d1edf9e7d0b84d72ea3dbcdfd22b35fb543a5e8f2a60092dd578936bf63d7f9"}, +] +pyyaml = [ + {file = "PyYAML-5.4.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922"}, + {file = "PyYAML-5.4.1-cp27-cp27m-win32.whl", hash = "sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393"}, + {file = "PyYAML-5.4.1-cp27-cp27m-win_amd64.whl", hash = "sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8"}, + {file = "PyYAML-5.4.1-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185"}, + {file = "PyYAML-5.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253"}, + {file = "PyYAML-5.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc"}, + {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347"}, + {file = "PyYAML-5.4.1-cp36-cp36m-manylinux2014_s390x.whl", hash = 
"sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541"}, + {file = "PyYAML-5.4.1-cp36-cp36m-win32.whl", hash = "sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5"}, + {file = "PyYAML-5.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df"}, + {file = "PyYAML-5.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018"}, + {file = "PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63"}, + {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa"}, + {file = "PyYAML-5.4.1-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0"}, + {file = "PyYAML-5.4.1-cp37-cp37m-win32.whl", hash = "sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b"}, + {file = "PyYAML-5.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf"}, + {file = "PyYAML-5.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46"}, + {file = "PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb"}, + {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247"}, + {file = "PyYAML-5.4.1-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc"}, + {file = "PyYAML-5.4.1-cp38-cp38-win32.whl", hash = "sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc"}, + {file = "PyYAML-5.4.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696"}, + {file = "PyYAML-5.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77"}, + {file = "PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183"}, + {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122"}, + {file = "PyYAML-5.4.1-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6"}, + {file = "PyYAML-5.4.1-cp39-cp39-win32.whl", hash = "sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10"}, + {file = "PyYAML-5.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db"}, + {file = "PyYAML-5.4.1.tar.gz", hash = "sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e"}, +] +regex = [ + {file = "regex-2021.4.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:619d71c59a78b84d7f18891fe914446d07edd48dc8328c8e149cbe0929b4e000"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:47bf5bf60cf04d72bf6055ae5927a0bd9016096bf3d742fa50d9bf9f45aa0711"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:281d2fd05555079448537fe108d79eb031b403dac622621c78944c235f3fcf11"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:bd28bc2e3a772acbb07787c6308e00d9626ff89e3bfcdebe87fa5afbfdedf968"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:7c2a1af393fcc09e898beba5dd59196edaa3116191cc7257f9224beaed3e1aa0"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:c38c71df845e2aabb7fb0b920d11a1b5ac8526005e533a8920aea97efb8ec6a4"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_i686.whl", hash = 
"sha256:96fcd1888ab4d03adfc9303a7b3c0bd78c5412b2bfbe76db5b56d9eae004907a"}, + {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:ade17eb5d643b7fead300a1641e9f45401c98eee23763e9ed66a43f92f20b4a7"}, + {file = "regex-2021.4.4-cp36-cp36m-win32.whl", hash = "sha256:e8e5b509d5c2ff12f8418006d5a90e9436766133b564db0abaec92fd27fcee29"}, + {file = "regex-2021.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:11d773d75fa650cd36f68d7ca936e3c7afaae41b863b8c387a22aaa78d3c5c79"}, + {file = "regex-2021.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d3029c340cfbb3ac0a71798100ccc13b97dddf373a4ae56b6a72cf70dfd53bc8"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:18c071c3eb09c30a264879f0d310d37fe5d3a3111662438889ae2eb6fc570c31"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:4c557a7b470908b1712fe27fb1ef20772b78079808c87d20a90d051660b1d69a"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:01afaf2ec48e196ba91b37451aa353cb7eda77efe518e481707e0515025f0cd5"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:3a9cd17e6e5c7eb328517969e0cb0c3d31fd329298dd0c04af99ebf42e904f82"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:90f11ff637fe8798933fb29f5ae1148c978cccb0452005bf4c69e13db951e765"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:919859aa909429fb5aa9cf8807f6045592c85ef56fdd30a9a3747e513db2536e"}, + {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:339456e7d8c06dd36a22e451d58ef72cef293112b559010db3d054d5560ef439"}, + {file = "regex-2021.4.4-cp37-cp37m-win32.whl", hash = "sha256:67bdb9702427ceddc6ef3dc382455e90f785af4c13d495f9626861763ee13f9d"}, + {file = "regex-2021.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:32e65442138b7b76dd8173ffa2cf67356b7bc1768851dded39a7a13bf9223da3"}, + {file = "regex-2021.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:1e1c20e29358165242928c2de1482fb2cf4ea54a6a6dea2bd7a0e0d8ee321500"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:314d66636c494ed9c148a42731b3834496cc9a2c4251b1661e40936814542b14"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:6d1b01031dedf2503631d0903cb563743f397ccaf6607a5e3b19a3d76fc10480"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:741a9647fcf2e45f3a1cf0e24f5e17febf3efe8d4ba1281dcc3aa0459ef424dc"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:4c46e22a0933dd783467cf32b3516299fb98cfebd895817d685130cc50cd1093"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e512d8ef5ad7b898cdb2d8ee1cb09a8339e4f8be706d27eaa180c2f177248a10"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:980d7be47c84979d9136328d882f67ec5e50008681d94ecc8afa8a65ed1f4a6f"}, + {file = "regex-2021.4.4-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:ce15b6d103daff8e9fee13cf7f0add05245a05d866e73926c358e871221eae87"}, + {file = "regex-2021.4.4-cp38-cp38-win32.whl", hash = "sha256:a91aa8619b23b79bcbeb37abe286f2f408d2f2d6f29a17237afda55bb54e7aac"}, + {file = "regex-2021.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:c0502c0fadef0d23b128605d69b58edb2c681c25d44574fc673b0e52dce71ee2"}, + {file = "regex-2021.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:598585c9f0af8374c28edd609eb291b5726d7cbce16be6a8b95aa074d252ee17"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:ee54ff27bf0afaf4c3b3a62bcd016c12c3fdb4ec4f413391a90bd38bc3624605"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7d9884d86dd4dd489e981d94a65cd30d6f07203d90e98f6f657f05170f6324c9"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:bf5824bfac591ddb2c1f0a5f4ab72da28994548c708d2191e3b87dd207eb3ad7"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux2010_x86_64.whl", hash = 
"sha256:563085e55b0d4fb8f746f6a335893bda5c2cef43b2f0258fe1020ab1dd874df8"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b9c3db21af35e3b3c05764461b262d6f05bbca08a71a7849fd79d47ba7bc33ed"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:3916d08be28a1149fb97f7728fca1f7c15d309a9f9682d89d79db75d5e52091c"}, + {file = "regex-2021.4.4-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:fd45ff9293d9274c5008a2054ecef86a9bfe819a67c7be1afb65e69b405b3042"}, + {file = "regex-2021.4.4-cp39-cp39-win32.whl", hash = "sha256:fa4537fb4a98fe8fde99626e4681cc644bdcf2a795038533f9f711513a862ae6"}, + {file = "regex-2021.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:97f29f57d5b84e73fbaf99ab3e26134e6687348e95ef6b48cfd2c06807005a07"}, + {file = "regex-2021.4.4.tar.gz", hash = "sha256:52ba3d3f9b942c49d7e4bc105bb28551c44065f139a65062ab7912bef10c9afb"}, +] +requests = [ + {file = "requests-2.25.1-py2.py3-none-any.whl", hash = "sha256:c210084e36a42ae6b9219e00e48287def368a26d03a048ddad7bfee44f75871e"}, + {file = "requests-2.25.1.tar.gz", hash = "sha256:27973dd4a904a4f13b263a19c866c13b92a39ed1c964655f025f3f8d3d75b804"}, +] +six = [ + {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, + {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, +] +smmap = [ + {file = "smmap-4.0.0-py2.py3-none-any.whl", hash = "sha256:a9a7479e4c572e2e775c404dcd3080c8dc49f39918c2cf74913d30c4c478e3c2"}, + {file = "smmap-4.0.0.tar.gz", hash = "sha256:7e65386bd122d45405ddf795637b7f7d2b532e7e401d46bbe3fb49b9986d5182"}, +] +toml = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] +tomlkit = [ + {file = "tomlkit-0.7.0-py2.py3-none-any.whl", 
hash = "sha256:6babbd33b17d5c9691896b0e68159215a9387ebfa938aa3ac42f4a4beeb2b831"}, + {file = "tomlkit-0.7.0.tar.gz", hash = "sha256:ac57f29693fab3e309ea789252fcce3061e19110085aa31af5446ca749325618"}, +] +torch = [ + {file = "torch-1.8.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:f23eeb1a48cc39209d986c418ad7e02227eee973da45c0c42d36b1aec72f4940"}, + {file = "torch-1.8.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4ace9c5bb94d5a7b9582cd089993201658466e9c59ff88bd4e9e08f6f072d1cf"}, + {file = "torch-1.8.1-cp36-cp36m-win_amd64.whl", hash = "sha256:6ffa1e7ae079c7cb828712cb0cdaae5cc4fb87c16a607e6d14526b62c20bcc17"}, + {file = "torch-1.8.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:16f2630d9604c4ee28ea7d6e388e2264cd7bc6031c6ecd796bae3f56b5efa9a3"}, + {file = "torch-1.8.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95b7bbbacc3f28fe438f418392ceeae146a01adc03b29d44917d55214ac234c9"}, + {file = "torch-1.8.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:55137feb2f5a0dc7aced5bba690dcdb7652054ad3452b09a2bbb59f02a11e9ff"}, + {file = "torch-1.8.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8ad2252bf09833dcf46a536a78544e349b8256a370e03a98627ebfb118d9555b"}, + {file = "torch-1.8.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:1388b30fbd262c1a053d6c9ace73bb0bd8f5871b4892b6f3e02d1d7bc9768563"}, + {file = "torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:e7ad1649adb7dc2a450e70a3e51240b84fa4746c69c8f98989ce0c254f9fba3a"}, + {file = "torch-1.8.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:3e4190c04dfd89c59bad06d5fe451446643a65e6d2607cc989eb1001ee76e12f"}, + {file = "torch-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:5c2e9a33d44cdb93ebd739b127ffd7da786bf5f740539539195195b186a05f6c"}, + {file = "torch-1.8.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c6ede2ae4dcd8214b63e047efabafa92493605205a947574cf358216ca4e440a"}, + {file = "torch-1.8.1-cp39-cp39-manylinux1_x86_64.whl", hash = 
"sha256:ce7d435426f3dd14f95710d779aa46e9cd5e077d512488e813f7589fdc024f78"}, + {file = "torch-1.8.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a50ea8ed900927fb30cadb63aa7a32fdd59c7d7abe5012348dfbe35a8355c083"}, + {file = "torch-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:dac4d10494e74f7e553c92d7263e19ea501742c4825ddd26c4decfa27be95981"}, + {file = "torch-1.8.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:225ee4238c019b28369c71977327deeeb2bd1c6b8557e6fcf631b8866bdc5447"}, +] +typed-ast = [ + {file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:2068531575a125b87a41802130fa7e29f26c09a2833fea68d9a40cf33902eba6"}, + {file = "typed_ast-1.4.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:c907f561b1e83e93fad565bac5ba9c22d96a54e7ea0267c708bffe863cbe4075"}, + {file = "typed_ast-1.4.3-cp35-cp35m-manylinux2014_aarch64.whl", hash = "sha256:1b3ead4a96c9101bef08f9f7d1217c096f31667617b58de957f690c92378b528"}, + {file = "typed_ast-1.4.3-cp35-cp35m-win32.whl", hash = "sha256:dde816ca9dac1d9c01dd504ea5967821606f02e510438120091b84e852367428"}, + {file = "typed_ast-1.4.3-cp35-cp35m-win_amd64.whl", hash = "sha256:777a26c84bea6cd934422ac2e3b78863a37017618b6e5c08f92ef69853e765d3"}, + {file = "typed_ast-1.4.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f8afcf15cc511ada719a88e013cec87c11aff7b91f019295eb4530f96fe5ef2f"}, + {file = "typed_ast-1.4.3-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:52b1eb8c83f178ab787f3a4283f68258525f8d70f778a2f6dd54d3b5e5fb4341"}, + {file = "typed_ast-1.4.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:01ae5f73431d21eead5015997ab41afa53aa1fbe252f9da060be5dad2c730ace"}, + {file = "typed_ast-1.4.3-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:c190f0899e9f9f8b6b7863debfb739abcb21a5c054f911ca3596d12b8a4c4c7f"}, + {file = "typed_ast-1.4.3-cp36-cp36m-win32.whl", hash = "sha256:398e44cd480f4d2b7ee8d98385ca104e35c81525dd98c519acff1b79bdaac363"}, + {file = "typed_ast-1.4.3-cp36-cp36m-win_amd64.whl", hash = 
"sha256:bff6ad71c81b3bba8fa35f0f1921fb24ff4476235a6e94a26ada2e54370e6da7"}, + {file = "typed_ast-1.4.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0fb71b8c643187d7492c1f8352f2c15b4c4af3f6338f21681d3681b3dc31a266"}, + {file = "typed_ast-1.4.3-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:760ad187b1041a154f0e4d0f6aae3e40fdb51d6de16e5c99aedadd9246450e9e"}, + {file = "typed_ast-1.4.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5feca99c17af94057417d744607b82dd0a664fd5e4ca98061480fd8b14b18d04"}, + {file = "typed_ast-1.4.3-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:95431a26309a21874005845c21118c83991c63ea800dd44843e42a916aec5899"}, + {file = "typed_ast-1.4.3-cp37-cp37m-win32.whl", hash = "sha256:aee0c1256be6c07bd3e1263ff920c325b59849dc95392a05f258bb9b259cf39c"}, + {file = "typed_ast-1.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9ad2c92ec681e02baf81fdfa056fe0d818645efa9af1f1cd5fd6f1bd2bdfd805"}, + {file = "typed_ast-1.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b36b4f3920103a25e1d5d024d155c504080959582b928e91cb608a65c3a49e1a"}, + {file = "typed_ast-1.4.3-cp38-cp38-manylinux1_i686.whl", hash = "sha256:067a74454df670dcaa4e59349a2e5c81e567d8d65458d480a5b3dfecec08c5ff"}, + {file = "typed_ast-1.4.3-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7538e495704e2ccda9b234b82423a4038f324f3a10c43bc088a1636180f11a41"}, + {file = "typed_ast-1.4.3-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:af3d4a73793725138d6b334d9d247ce7e5f084d96284ed23f22ee626a7b88e39"}, + {file = "typed_ast-1.4.3-cp38-cp38-win32.whl", hash = "sha256:f2362f3cb0f3172c42938946dbc5b7843c2a28aec307c49100c8b38764eb6927"}, + {file = "typed_ast-1.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:dd4a21253f42b8d2b48410cb31fe501d32f8b9fbeb1f55063ad102fe9c425e40"}, + {file = "typed_ast-1.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f328adcfebed9f11301eaedfa48e15bdece9b519fb27e6a8c01aa52a17ec31b3"}, + {file = "typed_ast-1.4.3-cp39-cp39-manylinux1_i686.whl", hash = 
"sha256:2c726c276d09fc5c414693a2de063f521052d9ea7c240ce553316f70656c84d4"}, + {file = "typed_ast-1.4.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cae53c389825d3b46fb37538441f75d6aecc4174f615d048321b716df2757fb0"}, + {file = "typed_ast-1.4.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b9574c6f03f685070d859e75c7f9eeca02d6933273b5e69572e5ff9d5e3931c3"}, + {file = "typed_ast-1.4.3-cp39-cp39-win32.whl", hash = "sha256:209596a4ec71d990d71d5e0d312ac935d86930e6eecff6ccc7007fe54d703808"}, + {file = "typed_ast-1.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:9c6d1a54552b5330bc657b7ef0eae25d00ba7ffe85d9ea8ae6540d2197a3788c"}, + {file = "typed_ast-1.4.3.tar.gz", hash = "sha256:fb1bbeac803adea29cedd70781399c99138358c26d05fcbd23c13016b7f5ec65"}, +] +typing-extensions = [ + {file = "typing_extensions-3.7.4.3-py2-none-any.whl", hash = "sha256:dafc7639cde7f1b6e1acc0f457842a83e722ccca8eef5270af2d74792619a89f"}, + {file = "typing_extensions-3.7.4.3-py3-none-any.whl", hash = "sha256:7cb407020f00f7bfc3cb3e7881628838e69d8f3fcab2f64742a5e76b2f841918"}, + {file = "typing_extensions-3.7.4.3.tar.gz", hash = "sha256:99d4073b617d30288f569d3f13d2bd7548c3a7e4c8de87db09a9d29bb3a4a60c"}, +] +unidiff = [ + {file = "unidiff-0.6.0-py2.py3-none-any.whl", hash = "sha256:e1dd956a492ccc4351e24931b2f2d29c79e3be17a99dd8f14e95324321d93a88"}, + {file = "unidiff-0.6.0.tar.gz", hash = "sha256:90c5214e9a357ff4b2fee19d91e77706638e3e00592a732d9405ea4e93da981f"}, +] +urllib3 = [ + {file = "urllib3-1.26.4-py2.py3-none-any.whl", hash = "sha256:2f4da4594db7e1e110a944bb1b551fdf4e6c136ad42e4234131391e21eb5b0df"}, + {file = "urllib3-1.26.4.tar.gz", hash = "sha256:e7b021f7241115872f92f43c6508082facffbd1c048e3c6e2bb9c2a157e28937"}, +] +zipp = [ + {file = "zipp-3.4.1-py3-none-any.whl", hash = "sha256:51cb66cc54621609dd593d1787f286ee42a5c0adbb4b29abea5a63edc3e03098"}, + {file = "zipp-3.4.1.tar.gz", hash = "sha256:3607921face881ba3e026887d8150cca609d517579abe052ac81fc5aeffdbd76"}, +] diff --git a/pyproject.toml 
b/pyproject.toml new file mode 100644 index 00000000..01d5632b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["poetry-core>=1.1.5"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "handyrl" +# Bump it to appropriate version if #183 done and ready to release. +version = "0.0.1" +description = "Simple framework of distributed reinforcement learning on Python/PyTorch" +authors = ["DeNA Co., Ltd."] +license = "MIT" + +[tool.poetry.dependencies] +python = "^3.7" + +PyYAML = "^5.4.1" +numpy = "^1.20.2" +psutil = "^5.8.0" +torch = "^1.8.1" + +[tool.poetry.dev-dependencies] +pysen = {version = "^0.9.1", extras = ["lint"]} +pytest = "^6.2.2" +# For handyrl/envs/kaggle/hungry_geese.py +kaggle-environments = "^1.7.11" +requests = "^2.25.1" + +[tool.pysen.lint] +enable_black = true +enable_flake8 = true +enable_isort = true +enable_mypy = false +# mypy_preset = "strict" +line_length = 128 +py_version = "py38" + +[[tool.pysen.lint.mypy_targets]] +paths = ["."] + +[tool.pytest.ini_options] +testpaths = [ + "tests", +] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..ef218fb6 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,48 @@ +appdirs==1.4.4; python_version >= "3.6" +atomicwrites==1.4.0; python_version >= "3.6" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.6" and python_full_version >= "3.4.0" +attrs==20.3.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" +black==20.8b1; python_version >= "3.6" +certifi==2020.12.5; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +chardet==4.0.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +click==7.1.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_version >= "3.6" and python_full_version >= 
"3.5.0" +colorama==0.4.4; python_version >= "3.6" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.6" and python_full_version >= "3.5.0" +colorlog==4.8.0 +dacite==1.6.0; python_version >= "3.6" +flake8-bugbear==21.4.3; python_version >= "3.6" +flake8==3.9.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6" +gitdb==4.0.7; python_version >= "3.4" +gitpython==3.1.14; python_version >= "3.4" +idna==2.10; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +importlib-metadata==3.10.0; python_version < "3.8" and python_version >= "3.6" and (python_version >= "3.6" and python_full_version < "3.0.0" and python_version < "3.8" or python_full_version >= "3.5.0" and python_version < "3.8" and python_version >= "3.6") and (python_version >= "3.6" and python_full_version < "3.0.0" and python_version < "3.8" or python_full_version >= "3.4.0" and python_version >= "3.6" and python_version < "3.8") +iniconfig==1.1.1; python_version >= "3.6" +isort==5.1.4; python_version >= "3.6" and python_version < "4.0" +jsonschema==3.2.0; python_version >= "3.6" +kaggle-environments==1.7.11; python_version >= "3.6" +mccabe==0.6.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +mypy-extensions==0.4.3; python_version >= "3.6" +mypy==0.790; python_version >= "3.5" +numpy==1.20.2; python_version >= "3.7" +packaging==20.9; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" +pathspec==0.8.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_version >= "3.6" and python_full_version >= "3.5.0" +pluggy==0.13.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" +psutil==5.8.0; (python_version >= "2.6" and python_full_version < 
"3.0.0") or (python_full_version >= "3.4.0") +py==1.10.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" +pycodestyle==2.7.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +pyflakes==2.3.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +pyparsing==2.4.7; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6" +pyrsistent==0.17.3; python_version >= "3.6" +pysen==0.9.1 +pytest==6.2.2; python_version >= "3.6" +pyyaml==5.4.1; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.6.0") +regex==2021.4.4; python_version >= "3.6" +requests==2.25.1; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.5.0") +six==1.15.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.6" +smmap==4.0.0; python_version >= "3.5" +toml==0.10.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.6" +tomlkit==0.7.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +torch==1.8.1; python_full_version >= "3.6.2" +typed-ast==1.4.3; python_version >= "3.6" +typing-extensions==3.7.4.3; python_version >= "3.6" and python_full_version >= "3.6.2" and python_version < "3.8" +unidiff==0.6.0 +urllib3==1.26.4; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version < "4" +zipp==3.4.1; python_version < "3.8" and python_version >= "3.6" diff --git a/requirements.txt b/requirements.txt old mode 100755 new mode 100644 index 90443c01..1405f171 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -PyYAML -numpy -torch -pytest -psutil +numpy==1.20.2; python_version >= "3.7" 
+psutil==5.8.0; (python_version >= "2.6" and python_full_version < "3.0.0") or (python_full_version >= "3.4.0") +pyyaml==5.4.1; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.6.0") +torch==1.8.1; python_full_version >= "3.6.2" +typing-extensions==3.7.4.3; python_full_version >= "3.6.2" diff --git a/tests/test_environment.py b/tests/test_environment.py index c137db3d..5925fa49 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,43 +1,39 @@ import importlib -import pytest import random import traceback +import pytest ENVS = [ - 'tictactoe', - 'geister', - 'parallel_tictactoe', - 'kaggle.hungry_geese', + "tictactoe", + "geister", + "parallel_tictactoe", + "kaggle.hungry_geese", ] @pytest.fixture def environment_path(): - return 'handyrl.envs' + return "handyrl.envs" -@pytest.mark.parametrize('env', ENVS) +@pytest.mark.parametrize("env", ENVS) def test_environment_property(environment_path, env): """Test properties of environment""" - try: - env_path = '.'.join([environment_path, env]) - env_module = importlib.import_module(env_path) - e = env_module.Environment() - e.players() - e.action_length() - str(e) - except Exception: - traceback.print_exc() - assert False + env_path = ".".join([environment_path, env]) + env_module = importlib.import_module(env_path) + e = env_module.Environment() + e.players() + e.action_length() + str(e) -@pytest.mark.parametrize('env', ENVS) +@pytest.mark.parametrize("env", ENVS) def test_environment_local(environment_path, env): """Test battle loop using local battle interface of environment""" no_error_loop = False try: - env_path = '.'.join([environment_path, env]) + env_path = ".".join([environment_path, env]) env_module = importlib.import_module(env_path) e = env_module.Environment() for _ in range(100): @@ -56,12 +52,12 @@ def test_environment_local(environment_path, env): assert no_error_loop -@pytest.mark.parametrize('env', ENVS) +@pytest.mark.parametrize("env", 
ENVS) def test_environment_network(environment_path, env): """Test battle loop using network battle interface of environment""" no_error_loop = False try: - env_path = '.'.join([environment_path, env]) + env_path = ".".join([environment_path, env]) env_module = importlib.import_module(env_path) e = env_module.Environment() es = {p: env_module.Environment() for p in e.players()}