Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce pysen #186

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions .github/workflows/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,61 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.7, 3.8, 3.9]
poetry-version: [1.1.5]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r handyrl/envs/kaggle/requirements.txt
- uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}
- run: |
poetry install
- name: pytest
run: |
python -m pytest tests
poetry run python -m pytest -v
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
poetry-version: [1.1.5]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}
- run: |
poetry install
- run: |
poetry run pysen run lint
validate-requirements:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
poetry-version: [1.1.5]
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}
- run: |
# If failed, do
# $ poetry export --without-hashes -f requirements.txt --output requirements.txt
# $ poetry export --without-hashes -f requirements.txt --output requirements-dev.txt --dev
diff <(poetry export --without-hashes -f requirements.txt | grep -v '^Warning:') requirements.txt
diff <(poetry export --without-hashes -f requirements.txt --dev | grep -v '^Warning:') requirements-dev.txt
40 changes: 39 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ pip3 install -r requirements.txt

To use games of kaggle environments (e.g. Hungry Geese) you can install also additional dependencies.
```
pip3 install -r handyrl/envs/kaggle/requirements.txt
pip3 install -r requirements-dev.txt
```

Or equivalently, you can use [poetry](https://python-poetry.org/).

```
poetry install
```


Expand Down Expand Up @@ -115,3 +121,35 @@ NOTE: Default opponent AI is random agent implemented in `evaluation.py`. You ca

* [Month 1 Winner in Hungry Geese (Kaggle)](https://www.kaggle.com/c/hungry-geese/discussion/222941)
* [The 5th solution in Google Research Football with Manchester City F.C. (Kaggle)](https://www.kaggle.com/c/google-football/discussion/203412)


## How to develop HandyRL

### Lint

```console
poetry run pysen run lint
```

You can fix some errors automatically:

```console
poetry run pysen run foramt
```

### Test

```console
poetry run pysen run pytest -v
```

### Tips

If you do not use IDE, the following command helps you:

```console
# At the first time, install `inotifywait`.
sudo apt-get install -y inotify-tools

function COMMAND { clear; poetry run pysen run lint && poetry run pytest -v }; COMMAND; while true; do if inotifywait -r -e modify . 2>/dev/null | git check-ignore --stdin >/dev/null; then :; else COMMAND; fi; done
```
22 changes: 11 additions & 11 deletions handyrl/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,32 @@ def observe(self, env, player, show=False):

class RuleBasedAgent(RandomAgent):
def action(self, env, player, show=False):
if hasattr(env, 'rule_based_action'):
if hasattr(env, "rule_based_action"):
return env.rule_based_action(player)
else:
return random.choice(env.legal_actions(player))


def view(env, player=None):
if hasattr(env, 'view'):
if hasattr(env, "view"):
env.view(player=player)
else:
print(env)


def view_transition(env):
if hasattr(env, 'view_transition'):
if hasattr(env, "view_transition"):
env.view_transition()
else:
pass


def print_outputs(env, prob, v):
if hasattr(env, 'print_outputs'):
if hasattr(env, "print_outputs"):
env.print_outputs(prob, v)
else:
print('v = %f' % v)
print('p = %s' % (prob * 1000).astype(int))
print("v = %f" % v)
print("p = %s" % (prob * 1000).astype(int))


class Agent:
Expand All @@ -65,14 +65,14 @@ def reset(self, env, show=False):

def plan(self, obs):
outputs = self.model.inference(obs, self.hidden)
self.hidden = outputs.pop('hidden', None)
self.hidden = outputs.pop("hidden", None)
return outputs

def action(self, env, player, show=False):
outputs = self.plan(env.observation(player))
actions = env.legal_actions(player)
p = outputs['policy']
v = outputs.get('value', None)
p = outputs["policy"]
v = outputs.get("value", None)
mask = np.ones_like(p)
mask[actions] = 0
p -= mask * 1e32
Expand All @@ -90,7 +90,7 @@ def action(self, env, player, show=False):
def observe(self, env, player, show=False):
if self.observation:
outputs = self.plan(env.observation(player))
v = outputs.get('value', None)
v = outputs.get("value", None)
if show:
view(env, player=player)
if self.observation:
Expand All @@ -106,7 +106,7 @@ def plan(self, obs):
for i, model in enumerate(self.model):
o = model.inference(obs, self.hidden[i])
for k, v in o:
if k == 'hidden':
if k == "hidden":
self.hidden[i] = v
else:
outputs[k] = outputs.get(k, []) + [o]
Expand Down
37 changes: 18 additions & 19 deletions handyrl/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# Licensed under The MIT License [see LICENSE for details]

import io
import time
import struct
import socket
import pickle
import threading
import queue
import multiprocessing as mp
import multiprocessing.connection as connection
import pickle
import queue
import socket
import struct
import threading
from typing import Any, List, Optional


def send_recv(conn, sdata):
Expand Down Expand Up @@ -45,7 +45,7 @@ def _recv(self, size):

def recv(self):
buf = self._recv(4)
size, = struct.unpack("!i", buf.getvalue())
(size,) = struct.unpack("!i", buf.getvalue())
buf = self._recv(size)
return pickle.loads(buf.getvalue())

Expand All @@ -72,11 +72,8 @@ def send(self, msg):

def open_socket_connection(port, reuse=False):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR,
sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1
)
sock.bind(('', int(port)))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1)
sock.bind(("", int(port)))
return sock


Expand All @@ -99,7 +96,7 @@ def connect_socket_connection(host, port):
try:
sock.connect((host, int(port)))
except ConnectionRefusedError:
print('failed to connect %s %d' % (host, port))
print("failed to connect %s %d" % (host, port))
return PickledConnection(sock)


Expand Down Expand Up @@ -165,7 +162,7 @@ def start(self):
thread.start()

def _sender(self):
print('start sender')
print("start sender")
while not self.shutdown_flag:
data = next(self.send_generator)
while not self.shutdown_flag:
Expand All @@ -175,10 +172,10 @@ def _sender(self):
break
except queue.Empty:
pass
print('finished sender')
print("finished sender")

def _receiver(self, index):
print('start receiver %d' % index)
print("start receiver %d" % index)
conns = [conn for i, conn in enumerate(self.conns) if i % self.num_receivers == index]
while not self.shutdown_flag:
tmp_conns = connection.wait(conns)
Expand All @@ -193,11 +190,13 @@ def _receiver(self, index):
break
except queue.Full:
pass
print('finished receiver %d' % index)
print("finished receiver %d" % index)


class QueueCommunicator:
def __init__(self, conns=[]):
def __init__(self, conns: Optional[List[Any]] = None):
conns = [] if conns is None else conns

self.input_queue = queue.Queue(maxsize=256)
self.output_queue = queue.Queue(maxsize=256)
self.conns = {}
Expand Down Expand Up @@ -228,7 +227,7 @@ def add_connection(self, conn):
self.conn_index += 1

def disconnect(self, conn):
print('disconnected')
print("disconnected")
self.conns.pop(conn, None)

def _send_thread(self):
Expand Down
25 changes: 13 additions & 12 deletions handyrl/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,29 @@
# game environment

import importlib

from typing import Any, Dict, Optional

ENVS = {
'TicTacToe': 'handyrl.envs.tictactoe',
'Geister': 'handyrl.envs.geister',
'ParallelTicTacToe': 'handyrl.envs.parallel_tictactoe',
'HungryGeese': 'handyrl.envs.kaggle.hungry_geese',
"TicTacToe": "handyrl.envs.tictactoe",
"Geister": "handyrl.envs.geister",
"ParallelTicTacToe": "handyrl.envs.parallel_tictactoe",
"HungryGeese": "handyrl.envs.kaggle.hungry_geese",
}


def prepare_env(env_args):
env_name = env_args['env']
env_name = env_args["env"]
env_source = ENVS.get(env_name, env_name)
env_module = importlib.import_module(env_source)

if env_module is None:
print("No environment %s" % env_name)
elif hasattr(env_module, 'prepare'):
elif hasattr(env_module, "prepare"):
env_module.prepare()


def make_env(env_args):
env_name = env_args['env']
env_name = env_args["env"]
env_source = ENVS.get(env_name, env_name)
env_module = importlib.import_module(env_source)

Expand All @@ -38,17 +38,18 @@ def make_env(env_args):

# base class of Environment


class BaseEnvironment:
def __init__(self, args={}):
def __init__(self, args: Optional[Dict[Any, Any]] = None):
pass

def __str__(self):
return ''
return ""

#
# Should be defined in all games
#
def reset(self, args={}):
def reset(self, args: Optional[Dict[Any, Any]]):
raise NotImplementedError()

#
Expand Down Expand Up @@ -135,7 +136,7 @@ def str2action(self, s, player=None):
# Should be defined if you use network battle mode
#
def diff_info(self, player=None):
return ''
return ""

#
# Should be defined if you use network battle mode
Expand Down
Loading