Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update A_B.ipynb #161

Open
wants to merge 3 commits into
base: dmarx.ab_notebook
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 102 additions & 89 deletions nbs/A_B.ipynb
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "i0RWQRJAkdJe"
},
"source": [
"# Stability.AI A/B Testing Notebook\n",
"\n",
Expand All @@ -38,10 +27,7 @@
" - Click again to deselect if you cahnge your mind\n",
" - The notebook does not currently constrain the user to only select one option, but that's how we recommend you use it. \n",
" - When you're satisfied with your selection, execute the cell again to log your feedback and generate a new set of images."
],
"metadata": {
"id": "i0RWQRJAkdJe"
}
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -242,12 +228,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GX85BLyFrGlJ"
},
"outputs": [],
"source": [
"%%writefile test_config.yaml\n",
"\n",
"### settings that will be used across test cases.\n",
"defaults:\n",
" grpc_host: grpc.stability.ai:443\n",
" host: grpc.stability.ai:443\n",
" # If API key not provided in test_config.yaml, user prompted with getpass\n",
" key:\n",
"\n",
Expand Down Expand Up @@ -285,25 +276,29 @@
" middle: ''\n",
" # Don't do this, results in `middle:\"None\"`\n",
" # middle:\n"
],
"metadata": {
"id": "GX85BLyFrGlJ"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xl_t3CjZlWvq"
},
"outputs": [],
"source": [
"# @markdown ## Load Experiments\n",
"\n",
"from omegaconf import OmegaConf\n",
"import getpass\n",
"from stability_sdk import client\n",
"\n",
"import shutil\n",
"import panel as pn\n",
"pn.extension()\n",
"\n",
"!mkdir fav\n",
"!mkdir results\n",
"\n",
"\n",
"exp_cfg_fpath_out = Path(workspace_cfg.project_root) / workspace_cfg.exp_cfg_fname\n",
"exp_cfg_fpath = exp_cfg_fpath_out\n",
Expand Down Expand Up @@ -342,7 +337,7 @@
"#####################################\n",
"\n",
"required_attributes = [\n",
" 'grpc_host',\n",
" 'host',\n",
" #'api_key'\n",
" 'key',\n",
"]\n",
Expand Down Expand Up @@ -379,76 +374,35 @@
"for test_case in cfg.differentiators:\n",
" running_score[test_case]+=0\n",
"\n"
],
"metadata": {
"id": "xl_t3CjZlWvq",
"cellView": "form"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ogPWQhWm7BXN"
},
"outputs": [],
"source": [
"# @markdown # Load a random sample to score preference\n",
"\n",
"SAMPLE_IDX += 1\n",
"\n",
"Copy_to_path = \"/content/results/\"\n",
"##########################\n",
"# Log experiment outcome #\n",
"##########################\n",
"\n",
"save_images = False # @param {type:'boolean'}\n",
"save_favorite_only = False # @param {type:'boolean'}\n",
"save_favorite_only = True # @param {type:'boolean'}\n",
"\n",
"\n",
"\n",
"# to do: make this not a closure.\n",
"def log_items(items):\n",
" #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n",
" recs = []\n",
" for item in items:\n",
" # assign image a filename\n",
" img_fname = f\"{RANDOM_PREFIX}_{SAMPLE_IDX}_{item['test_case']}.png\"\n",
" #rec = copy.deepcopy(item)\n",
" rec = item\n",
" img_fpath = Path(workspace_cfg.project_root) / img_fname\n",
" # save image\n",
" img = rec.pop('img')\n",
" save_im = False\n",
" if save_images or save_favorite_only:\n",
" save_im = True\n",
" if save_favorite_only and not rec['is_preference']:\n",
" save_im = False\n",
" if save_im:\n",
" print(img_fpath)\n",
" rec['img_fpath'] = str(img_fpath)\n",
" img.save(img_fpath)\n",
" # update outcome\n",
" rec['is_preference'] = rec['button'].value\n",
" if rec['is_preference']:\n",
" running_score[rec['test_case']] += 1\n",
" rec.pop('button')\n",
" # log outcome\n",
" recs.append(rec)\n",
" outfile = Path(workspace_cfg.project_root) / explog_fname\n",
" #with open(outfile, 'a') as f:\n",
" with outfile.open('a') as f:\n",
" json.dump(recs, f)\n",
" f.write('\\n')\n",
" logger.debug(running_score)\n",
" \n",
"if items:\n",
" try:\n",
" log_items(items)\n",
" posterior_plot(running_score)\n",
" except KeyError:\n",
" # fuck it\n",
" pass\n",
"\n",
"\n",
"SEED = random.randrange(0, 4294967295)\n",
"\n",
"blind_test = False # @param {type: \"boolean\"}\n",
"blind_test = True # @param {type: \"boolean\"}\n",
"\n",
"def item_to_ux(\n",
" item\n",
Expand All @@ -469,10 +423,11 @@
" output.append(toggle)\n",
" item['button'] = toggle\n",
" item['is_preference'] = toggle.value\n",
" print(item)\n",
" return pn.Column(*output)\n",
"\n",
"\n",
"non_generation_arguments = ['grpc_host', 'engine', 'key']\n",
"non_generation_arguments = ['host', 'engine', 'key']\n",
"\n",
"rec = random.choice(experiments)\n",
"\n",
Expand Down Expand Up @@ -508,6 +463,10 @@
" if artifact.type == generation.ARTIFACT_IMAGE:\n",
" img = Image.open(io.BytesIO(artifact.binary))\n",
" img = img.resize([512, 512])\n",
" if save_images or save_favorite_only:\n",
" img_fname = f\"{kwargs_gen['prompt']}_{kwargs_gen['seed']}_{SAMPLE_IDX}_{test_case}.png\"\n",
" img.save(Copy_to_path+img_fname)\n",
"\n",
" items.append({\n",
" 'img':img,\n",
" 'test_case':test_case,\n",
Expand All @@ -517,18 +476,72 @@
" 'SDK_VERSION':SDK_VERSION,\n",
" 'timestamp':time.time(),\n",
" 'user_id': workspace_cfg.notebook_user_id,\n",
" 'project_name':workspace_cfg.active_project,\n",
" 'project_name':workspace_cfg.active_project, \n",
" })\n",
"\n",
"random.shuffle(items)\n",
"pn.Row(*[item_to_ux(it) for it in items])"
],
"metadata": {
"id": "ogPWQhWm7BXN",
"cellView": "form"
},
"pn.Row(*[item_to_ux(it) for it in items]) "
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": []
"metadata": {},
"outputs": [],
"source": [
"# to do: make this not a closure.\n",
"import shutil\n",
"def log_items(items):\n",
" #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n",
" recs = []\n",
" for item in items:\n",
" # assign image a filename\n",
" # rec = copy.deepcopy(item)\n",
" rec = item\n",
" # save image\n",
" rec['is_preference'] = rec['button'].value\n",
" if save_favorite_only:\n",
" if rec['is_preference']:\n",
" img_fname1 = f\"{rec['kwargs']['prompt']}_{rec['kwargs']['seed']}_{SAMPLE_IDX}_{rec['test_case']}.png\"\n",
" print(type(Copy_to_path+img_fname1))\n",
" shutil.move(Copy_to_path+img_fname1, '/content/fav')\n",
" # update outcome\n",
" if rec['is_preference']:\n",
" running_score[rec['test_case']] += 1\n",
" rec.pop('img')\n",
" rec.pop('button')\n",
" # log outcome\n",
" recs.append(rec)\n",
" outfile = Path(workspace_cfg.project_root) / explog_fname\n",
" #with open(outfile, 'a') as f:\n",
" with outfile.open('a') as f:\n",
" json.dump(recs, f)\n",
" f.write('\\n')\n",
" logger.debug(running_score)\n",
"\n",
"# print(items)\n",
"if items:\n",
" try:\n",
" log_items(items)\n",
" posterior_plot(running_score)\n",
" except KeyError:\n",
" # fuck it\n",
" pass"
]
}
]
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}