diff --git a/nbs/A_B.ipynb b/nbs/A_B.ipynb index 1e334539..abba7f04 100644 --- a/nbs/A_B.ipynb +++ b/nbs/A_B.ipynb @@ -1,21 +1,10 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, "cells": [ { "cell_type": "markdown", + "metadata": { + "id": "i0RWQRJAkdJe" + }, "source": [ "# Stability.AI A/B Testing Notebook\n", "\n", @@ -38,10 +27,7 @@ " - Click again to deselect if you cahnge your mind\n", " - The notebook does not currently constrain the user to only select one option, but that's how we recommend you use it. \n", " - When you're satisfied with your selection, execute the cell again to log your feedback and generate a new set of images." - ], - "metadata": { - "id": "i0RWQRJAkdJe" - } + ] }, { "cell_type": "code", @@ -242,12 +228,17 @@ }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GX85BLyFrGlJ" + }, + "outputs": [], "source": [ "%%writefile test_config.yaml\n", "\n", "### settings that will be used across test cases.\n", "defaults:\n", - " grpc_host: grpc.stability.ai:443\n", + " host: grpc.stability.ai:443\n", " # If API key not provided in test_config.yaml, user prompted with getpass\n", " key:\n", "\n", @@ -285,25 +276,29 @@ " middle: ''\n", " # Don't do this, results in `middle:\"None\"`\n", " # middle:\n" - ], - "metadata": { - "id": "GX85BLyFrGlJ" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "xl_t3CjZlWvq" + }, + "outputs": [], "source": [ "# @markdown ## Load Experiments\n", "\n", "from omegaconf import OmegaConf\n", "import getpass\n", "from stability_sdk import client\n", - "\n", + "import shutil\n", "import panel as pn\n", "pn.extension()\n", "\n", + "!mkdir fav\n", + "!mkdir results\n", + "\n", "\n", "exp_cfg_fpath_out = Path(workspace_cfg.project_root) / workspace_cfg.exp_cfg_fname\n", "exp_cfg_fpath = exp_cfg_fpath_out\n", @@ -342,7 +337,7 @@ "#####################################\n", "\n", "required_attributes = [\n", - " 'grpc_host',\n", + " 'host',\n", " #'api_key'\n", " 'key',\n", "]\n", @@ -379,76 +374,35 @@ "for test_case in cfg.differentiators:\n", " running_score[test_case]+=0\n", "\n" - ], - "metadata": { - "id": "xl_t3CjZlWvq", - "cellView": "form" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ogPWQhWm7BXN" + }, + "outputs": [], "source": [ "# @markdown # Load a random sample to score preference\n", "\n", "SAMPLE_IDX += 1\n", - "\n", + "Copy_to_path = \"/content/results/\"\n", "##########################\n", "# Log experiment outcome #\n", "##########################\n", "\n", "save_images = False # @param {type:'boolean'}\n", - "save_favorite_only = False # @param {type:'boolean'}\n", + "save_favorite_only = True # @param {type:'boolean'}\n", + "\n", "\n", "\n", - "# to do: make this not a closure.\n", - "def log_items(items):\n", - " #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n", - " recs = []\n", - " for item in items:\n", - " # assign image a filename\n", - " img_fname = f\"{RANDOM_PREFIX}_{SAMPLE_IDX}_{item['test_case']}.png\"\n", - " #rec = copy.deepcopy(item)\n", - " rec = item\n", - " img_fpath = Path(workspace_cfg.project_root) / img_fname\n", - " # save image\n", - " img = rec.pop('img')\n", - " save_im = False\n", - " if save_images or save_favorite_only:\n", - " save_im = True\n", - " if save_favorite_only and not rec['is_preference']:\n", - " save_im = False\n", - " if save_im:\n", - " print(img_fpath)\n", - " rec['img_fpath'] = str(img_fpath)\n", - " img.save(img_fpath)\n", - " # update outcome\n", - " rec['is_preference'] = rec['button'].value\n", - " if rec['is_preference']:\n", - " running_score[rec['test_case']] += 1\n", - " rec.pop('button')\n", - " # log outcome\n", - " recs.append(rec)\n", - " outfile = Path(workspace_cfg.project_root) / explog_fname\n", - " #with open(outfile, 'a') as f:\n", - " with outfile.open('a') as f:\n", - " json.dump(recs, f)\n", - " f.write('\\n')\n", - " logger.debug(running_score)\n", - " \n", - "if items:\n", - " try:\n", - " log_items(items)\n", - " posterior_plot(running_score)\n", - " except KeyError:\n", - " # fuck it\n", - " pass\n", "\n", "\n", "SEED = random.randrange(0, 4294967295)\n", "\n", - "blind_test = False # @param {type: \"boolean\"}\n", + "blind_test = True # @param {type: \"boolean\"}\n", "\n", "def item_to_ux(\n", " item\n", @@ -469,10 +423,11 @@ " output.append(toggle)\n", " item['button'] = toggle\n", " item['is_preference'] = toggle.value\n", + " print(item)\n", " return pn.Column(*output)\n", "\n", "\n", - "non_generation_arguments = ['grpc_host', 'engine', 'key']\n", + "non_generation_arguments = ['host', 'engine', 'key']\n", "\n", "rec = random.choice(experiments)\n", "\n", @@ -508,6 +463,10 @@ " if artifact.type == generation.ARTIFACT_IMAGE:\n", " img = Image.open(io.BytesIO(artifact.binary))\n", " img = img.resize([512, 512])\n", + " if save_images or save_favorite_only:\n", + " img_fname = f\"{kwargs_gen['prompt']}_{kwargs_gen['seed']}_{SAMPLE_IDX}_{test_case}.png\"\n", + " img.save(Copy_to_path+img_fname)\n", + "\n", " items.append({\n", " 'img':img,\n", " 'test_case':test_case,\n", @@ -517,18 +476,72 @@ " 'SDK_VERSION':SDK_VERSION,\n", " 'timestamp':time.time(),\n", " 'user_id': workspace_cfg.notebook_user_id,\n", - " 'project_name':workspace_cfg.active_project,\n", + " 'project_name':workspace_cfg.active_project, \n", " })\n", "\n", "random.shuffle(items)\n", - "pn.Row(*[item_to_ux(it) for it in items])" - ], - "metadata": { - "id": "ogPWQhWm7BXN", - "cellView": "form" - }, + "pn.Row(*[item_to_ux(it) for it in items]) " + ] + }, + { + "cell_type": "code", "execution_count": null, - "outputs": [] + "metadata": {}, + "outputs": [], + "source": [ + "# to do: make this not a closure.\n", + "import shutil\n", + "def log_items(items):\n", + " #for (img, test_case, kwargs_gen, is_preference) in items: # to do: dictify\n", + " recs = []\n", + " for item in items:\n", + " # assign image a filename\n", + " # rec = copy.deepcopy(item)\n", + " rec = item\n", + " # save image\n", + " rec['is_preference'] = rec['button'].value\n", + " if save_favorite_only:\n", + " if rec['is_preference']:\n", + " img_fname1 = f\"{rec['kwargs']['prompt']}_{rec['kwargs']['seed']}_{SAMPLE_IDX}_{rec['test_case']}.png\"\n", + " print(type(Copy_to_path+img_fname1))\n", + " shutil.move(Copy_to_path+img_fname1, '/content/fav')\n", + " # update outcome\n", + " if rec['is_preference']:\n", + " running_score[rec['test_case']] += 1\n", + " rec.pop('img')\n", + " rec.pop('button')\n", + " # log outcome\n", + " recs.append(rec)\n", + " outfile = Path(workspace_cfg.project_root) / explog_fname\n", + " #with open(outfile, 'a') as f:\n", + " with outfile.open('a') as f:\n", + " json.dump(recs, f)\n", + " f.write('\\n')\n", + " logger.debug(running_score)\n", + "\n", + "# print(items)\n", + "if items:\n", + " try:\n", + " log_items(items)\n", + " posterior_plot(running_score)\n", + " except KeyError:\n", + " # fuck it\n", + " pass" + ] } - ] + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 }