diff --git a/tutorials/model_ensembling.ipynb b/tutorials/model_ensembling.ipynb index d42b4468..fe0fc2e6 100644 --- a/tutorials/model_ensembling.ipynb +++ b/tutorials/model_ensembling.ipynb @@ -44,7 +44,7 @@ "id": "8bd6262f", "metadata": {}, "source": [ - "** Note: ** Part of this notebook (in Section 4) uses the `scikit-learn` package. " + "**Note:** Part of this notebook (in Section 4) uses the `scikit-learn` package. " ] }, { @@ -65,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "89ef2804", "metadata": {}, "outputs": [ @@ -73,74 +73,65 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mWARNING: Skipping /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/docutils-0.18.1.dist-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Skipping /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/docutils-0.18.1.dist-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: seaborn in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (0.13.2)\n", - "Requirement already satisfied: scikit-learn in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (1.5.1)\n", - "Requirement already satisfied: cornac==2.2.2 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (2.2.2)\n", - "Requirement already satisfied: tensorflow==2.12.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (2.12.0)\n", - "Requirement already satisfied: numpy<2.0.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from cornac==2.2.2) (1.23.5)\n", - "Requirement already satisfied: scipy<=1.13.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from cornac==2.2.2) (1.10.1)\n", - "Requirement already satisfied: tqdm in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from cornac==2.2.2) (4.65.0)\n", - "Requirement already satisfied: powerlaw in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/powerlaw-1.5-py3.11.egg (from cornac==2.2.2) (1.5)\n", - "Requirement already satisfied: absl-py>=1.0.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (2.1.0)\n", - "Requirement already satisfied: astunparse>=1.6.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (1.6.3)\n", - "Requirement already satisfied: flatbuffers>=2.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (24.3.25)\n", - "Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (0.4.0)\n", - "Requirement already satisfied: google-pasta>=0.1.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (0.2.0)\n", - "Requirement already satisfied: h5py>=2.9.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (3.11.0)\n", - "Requirement already satisfied: jax>=0.3.15 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (0.4.30)\n", - "Requirement already satisfied: libclang>=13.0.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (18.1.1)\n", - "Requirement already satisfied: opt-einsum>=2.3.2 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (3.3.0)\n", - "Requirement already satisfied: packaging in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/packaging-23.1-py3.11.egg (from tensorflow==2.12.0) (23.1)\n", - "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (4.25.3)\n", - "Requirement already satisfied: setuptools in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (67.8.0)\n", - "Requirement already satisfied: six>=1.12.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/six-1.16.0-py3.11.egg (from tensorflow==2.12.0) (1.16.0)\n", - "Requirement already satisfied: termcolor>=1.1.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (2.4.0)\n", - "Requirement already satisfied: typing-extensions>=3.6.6 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (4.12.2)\n", - "Requirement already satisfied: wrapt<1.15,>=1.11.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (1.14.1)\n", - "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (1.65.1)\n", - "Requirement already satisfied: tensorboard<2.13,>=2.12 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (2.12.3)\n", - "Requirement already satisfied: tensorflow-estimator<2.13,>=2.12.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (2.12.0)\n", - "Requirement already satisfied: keras<2.13,>=2.12.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (2.12.0)\n", - "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorflow==2.12.0) (0.37.1)\n", - "Requirement already satisfied: pandas>=1.2 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from seaborn) (2.0.2)\n", - "Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/matplotlib-3.7.1-py3.11-macosx-10.9-x86_64.egg (from seaborn) (3.7.1)\n", - "Requirement already satisfied: joblib>=1.2.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from scikit-learn) (1.4.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from scikit-learn) (3.5.0)\n", - "Requirement already satisfied: wheel<1.0,>=0.23.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from astunparse>=1.6.0->tensorflow==2.12.0) (0.38.4)\n", - "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from jax>=0.3.15->tensorflow==2.12.0) (0.4.30)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from jax>=0.3.15->tensorflow==2.12.0) (0.3.2)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/contourpy-1.1.0-py3.11-macosx-10.9-x86_64.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (1.1.0)\n", - "Requirement already satisfied: cycler>=0.10 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/cycler-0.11.0-py3.11.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (0.11.0)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/fonttools-4.40.0-py3.11.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (4.40.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/kiwisolver-1.4.4-py3.11-macosx-10.9-x86_64.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (1.4.4)\n", - "Requirement already satisfied: pillow>=6.2.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/Pillow-9.5.0-py3.11-macosx-10.9-x86_64.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (9.5.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/pyparsing-3.1.0-py3.11.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (3.1.0)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/python_dateutil-2.8.2-py3.11.egg (from matplotlib!=3.6.1,>=3.4->seaborn) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from pandas>=1.2->seaborn) (2023.3)\n", - "Requirement already satisfied: tzdata>=2022.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from pandas>=1.2->seaborn) (2023.3)\n", - "Requirement already satisfied: google-auth<3,>=1.6.3 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.35.0)\n", - "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (1.0.0)\n", - "Requirement already satisfied: markdown>=2.6.8 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.6)\n", - "Requirement already satisfied: requests<3,>=2.21.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.31.0)\n", - "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.7.2)\n", - "Requirement already satisfied: werkzeug>=1.0.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.0.1)\n", - "Requirement already satisfied: mpmath in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from powerlaw->cornac==2.2.2) (1.3.0)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (5.5.0)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.4.1)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (4.9)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.0.0)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.0.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.0.2)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2023.5.7)\n", - "Requirement already satisfied: MarkupSafe>=2.1.1 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from werkzeug>=1.0.1->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.1.1)\n", - "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.6.1)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.2.2)\n", - "\u001b[33mWARNING: Skipping /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/docutils-0.18.1.dist-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Skipping /Users/darrylong/miniconda3/envs/gcmc-pytorch/lib/python3.11/site-packages/docutils-0.18.1.dist-info due to invalid metadata entry 'name'\u001b[0m\u001b[33m\n", - "\u001b[0m" + "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.10/site-packages (1.5.1)\n", + "Requirement already satisfied: cornac==2.2.2 in /opt/conda/lib/python3.10/site-packages (2.2.2)\n", + "Requirement already satisfied: tensorflow==2.12.0 in /opt/conda/lib/python3.10/site-packages (2.12.0)\n", + "Requirement already satisfied: numpy<2.0.0 in /opt/conda/lib/python3.10/site-packages (from cornac==2.2.2) (1.23.5)\n", + "Requirement already satisfied: scipy<=1.13.1 in /opt/conda/lib/python3.10/site-packages (from cornac==2.2.2) (1.13.1)\n", + "Requirement already satisfied: tqdm in /opt/conda/lib/python3.10/site-packages (from cornac==2.2.2) (4.66.5)\n", + "Requirement already satisfied: powerlaw in /opt/conda/lib/python3.10/site-packages (from cornac==2.2.2) (1.5)\n", + "Requirement already satisfied: absl-py>=1.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (2.1.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=2.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (24.3.25)\n", + "Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (0.4.0)\n", + "Requirement already satisfied: google-pasta>=0.1.1 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (0.2.0)\n", + "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (1.66.2)\n", + "Requirement already satisfied: h5py>=2.9.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (3.12.1)\n", + "Requirement already satisfied: jax>=0.3.15 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (0.4.30)\n", + "Requirement already satisfied: keras<2.13,>=2.12.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (2.12.0)\n", + "Requirement already satisfied: libclang>=13.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (18.1.1)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (3.4.0)\n", + "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (24.0)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (4.25.5)\n", + "Requirement already satisfied: setuptools in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (69.5.1)\n", + "Requirement already satisfied: six>=1.12.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (1.16.0)\n", + "Requirement already satisfied: tensorboard<2.13,>=2.12 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (2.12.3)\n", + "Requirement already satisfied: tensorflow-estimator<2.13,>=2.12.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (2.12.0)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (2.4.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.6 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (4.12.2)\n", + "Requirement already satisfied: wrapt<1.15,>=1.11.0 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (1.14.1)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /opt/conda/lib/python3.10/site-packages (from tensorflow==2.12.0) (0.37.1)\n", + "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn) (3.5.0)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /opt/conda/lib/python3.10/site-packages (from astunparse>=1.6.0->tensorflow==2.12.0) (0.43.0)\n", + "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /opt/conda/lib/python3.10/site-packages (from jax>=0.3.15->tensorflow==2.12.0) (0.4.30)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /opt/conda/lib/python3.10/site-packages (from jax>=0.3.15->tensorflow==2.12.0) (0.5.0)\n", + "Requirement already satisfied: google-auth<3,>=1.6.3 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.35.0)\n", + "Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (1.0.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.7)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.32.3)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.0.4)\n", + "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (from powerlaw->cornac==2.2.2) (3.9.2)\n", + "Requirement already satisfied: mpmath in /opt/conda/lib/python3.10/site-packages (from powerlaw->cornac==2.2.2) (1.3.0)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /opt/conda/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (5.5.0)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.4.1)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (4.9)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.10/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.0.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.2.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2024.8.30)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /opt/conda/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (2.1.5)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (1.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (4.54.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (1.4.7)\n", + "Requirement already satisfied: pillow>=8 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (10.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (3.1.4)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib->powerlaw->cornac==2.2.2) (2.9.0)\n", + "Requirement already satisfied: pyasn1<0.7.0,>=0.4.6 in /opt/conda/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (0.6.1)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow==2.12.0) (3.2.2)\n" ] } ], @@ -150,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "bd266ee7", "metadata": {}, "outputs": [], @@ -185,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "92a57076", "metadata": {}, "outputs": [ @@ -193,7 +184,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "rating_threshold = 1.0\n", + "rating_threshold = 4.0\n", "exclude_unknowns = True\n", "---\n", "Training data:\n", @@ -219,7 +210,7 @@ "source": [ "data = movielens.load_feedback(variant=\"100K\") # Load MovieLens Dataset\n", "\n", - "rs = RatioSplit(data, test_size=0.2, seed=42, verbose=True) # Split to train-test set to 80-20\n", + "rs = RatioSplit(data, test_size=0.2, rating_threshold=4.0, seed=42, verbose=True) # Split to train-test set to 80-20\n", "train_set, test_set = rs.train_set, rs.test_set" ] }, @@ -238,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "ea466b90", "metadata": {}, "outputs": [ @@ -253,11 +244,18 @@ ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Ranking: 100%|██████████| 940/940 [00:00<00:00, 1701.10it/s]\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "37057803cd09410084cc147d15d58e4a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Ranking: 0%| | 0/940 [00:00\n", " \n", " Drama\n", - " 71\n", - " 24.9\n", + " 117\n", + " 22.6\n", " \n", " \n", " Comedy\n", - " 39\n", - " 13.7\n", + " 72\n", + " 13.9\n", " \n", " \n", " Romance\n", - " 32\n", - " 11.2\n", + " 56\n", + " 10.8\n", " \n", " \n", " Action\n", - " 30\n", - " 10.5\n", + " 55\n", + " 10.6\n", " \n", " \n", " Thriller\n", - " 29\n", - " 10.2\n", + " 50\n", + " 9.7\n", " \n", " \n", " Adventure\n", - " 19\n", - " 6.7\n", + " 36\n", + " 6.9\n", + " \n", + " \n", + " Children's\n", + " 23\n", + " 4.4\n", " \n", " \n", " War\n", - " 15\n", - " 5.3\n", + " 20\n", + " 3.9\n", " \n", " \n", " Crime\n", - " 12\n", - " 4.2\n", + " 20\n", + " 3.9\n", " \n", " \n", " Sci-Fi\n", - " 9\n", - " 3.2\n", - " \n", - " \n", - " Mystery\n", - " 8\n", - " 2.8\n", + " 18\n", + " 3.5\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Sum %\n", - "Drama 71 24.9\n", - "Comedy 39 13.7\n", - "Romance 32 11.2\n", - "Action 30 10.5\n", - "Thriller 29 10.2\n", - "Adventure 19 6.7\n", - "War 15 5.3\n", - "Crime 12 4.2\n", - "Sci-Fi 9 3.2\n", - "Mystery 8 2.8" + " Sum %\n", + "Drama 117 22.6\n", + "Comedy 72 13.9\n", + "Romance 56 10.8\n", + "Action 55 10.6\n", + "Thriller 50 9.7\n", + "Adventure 36 6.9\n", + "Children's 23 4.4\n", + "War 20 3.9\n", + "Crime 20 3.9\n", + "Sci-Fi 18 3.5" ] }, "metadata": {}, @@ -721,11 +726,13 @@ "UIDX = 3\n", "TOPK = 50\n", "\n", - "# Filter training data (rating = 5.0 and user index = UIDX)\n", - "filter_df = training_data_df[(training_data_df['rating'] == 5.0) & (training_data_df['user_idx'] == UIDX)]\n", - "filter_df = item_df.loc[[int(item_id) for item_id in filter_df[\"item_id\"]]] # get genres of movie items\n", + "# Positively rated items by a user (rating >= 4.0 as rating_threshold used earlier, and user index = UIDX)\n", + "positively_rated_items = training_data_df[\n", + " (training_data_df['rating'] >= 4.0) & (training_data_df['user_idx'] == UIDX)\n", + "]['item_id'].unique()\n", + "filter_df = item_df.loc[[int(item_id) for item_id in positively_rated_items]] # get genres of movie items\n", "\n", - "print(\"Number of movies:\", len(filter_df)) # Number of movies rated 5.0 by user index 3 in training data\n", + "print(\"Number of movies:\", len(filter_df)) # Number of movies positvely rated by user index 3 in training data\n", "\n", "# Group by Movie Genre and Sum by genres\n", "filter_df = filter_df.select_dtypes(np.number).sum() \n", @@ -736,7 +743,7 @@ "filter_df[\"%\"] = filter_df[\"%\"].round(1)\n", "\n", "# Let's see the training data genres, sums and percentages\n", - "print(\"Movies rated 5.0 by user index 3 in training data\")\n", + "print(\"Positively rated movies by user index 3 in training data\")\n", "display(filter_df.sort_values(\"Sum\", ascending=False)[:10])" ] }, @@ -745,7 +752,7 @@ "id": "d700c1c1", "metadata": {}, "source": [ - "As shown above in the training data, the top genres for user index 3 with movies rated 5.0 include 'Drama', 'Comedy', 'Romance', 'Action' and 'Thriller'.\n", + "As shown above in the training data, the top genres for user index 3 with positively rated movies include 'Drama', 'Comedy', 'Romance', 'Action' and 'Thriller'.\n", "\n", "Let's now compare them to the recommendations of the BPR and WMF models respectively." ] @@ -760,7 +767,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "id": "72759171", "metadata": {}, "outputs": [ @@ -769,8 +776,7 @@ "output_type": "stream", "text": [ "\n", - "Top 5 Genres in training data: ['Drama', 'Comedy', 'Romance', 'Action', 'Thriller']\n", - "50\n" + "Top 5 Genres in training data: ['Drama', 'Comedy', 'Romance', 'Action', 'Thriller']\n" ] }, { @@ -1013,20 +1019,20 @@ " 0\n", " \n", " \n", - " 125\n", - " Phenomenon (1996)\n", + " 8\n", + " Babe (1995)\n", " 1\n", - " 0\n", " 1\n", " 0\n", " 0\n", + " 0\n", " \n", " \n", - " 8\n", - " Babe (1995)\n", - " 1\n", + " 125\n", + " Phenomenon (1996)\n", " 1\n", " 0\n", + " 1\n", " 0\n", " 0\n", " \n", @@ -1093,8 +1099,8 @@ "ItemID \n", "313 Titanic (1997) 1 0 1 1 \n", "204 Back to the Future (1985) 0 1 0 0 \n", - "125 Phenomenon (1996) 1 0 1 0 \n", "8 Babe (1995) 1 1 0 0 \n", + "125 Phenomenon (1996) 1 0 1 0 \n", "318 Schindler's List (1993) 1 0 0 0 \n", "15 Mr. Holland's Opus (1995) 1 0 0 0 \n", "64 Shawshank Redemption, The (1994) 1 0 0 0 \n", @@ -1106,8 +1112,8 @@ "ItemID \n", "313 0 \n", "204 0 \n", - "125 0 \n", "8 0 \n", + "125 0 \n", "318 0 \n", "15 0 \n", "64 0 \n", @@ -1128,7 +1134,6 @@ "# Get top K recommendations for BPR and put them into the genre dataframe\n", "bpr_recommendations, bpr_scores = bpr_model.rank(UIDX) # rank recommendations by score, limit to top K\n", "bpr_recommendations = bpr_recommendations[:TOPK] # limit to top K\n", - "print(len(bpr_recommendations))\n", "bpr_topk = [item_idx2id[iidx] for iidx in bpr_recommendations] # convert item indexes into item ids\n", "bpr_df = item_df.loc[[int(iid) for iid in bpr_topk]] # filter the movie genre dataframe by item ids\n", "\n", @@ -1163,7 +1168,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "id": "283ca840", "metadata": {}, "outputs": [ @@ -1205,80 +1210,80 @@ " \n", " \n", " Drama\n", - " 24.9\n", + " 22.6\n", " 17.1\n", - " 25.5\n", + " 26.0\n", " \n", " \n", " Comedy\n", + " 13.9\n", " 13.7\n", - " 13.7\n", - " 20.6\n", + " 21.0\n", " \n", " \n", " Romance\n", - " 11.2\n", + " 10.8\n", " 17.9\n", - " 17.6\n", + " 18.0\n", " \n", " \n", " Action\n", - " 10.5\n", + " 10.6\n", " 13.7\n", - " 7.8\n", + " 7.0\n", " \n", " \n", " Thriller\n", - " 10.2\n", + " 9.7\n", " 12.8\n", - " 6.9\n", + " 6.0\n", " \n", " \n", " Adventure\n", - " 6.7\n", + " 6.9\n", " 6.8\n", - " 4.9\n", + " 5.0\n", + " \n", + " \n", + " Children's\n", + " 4.4\n", + " 1.7\n", + " 2.0\n", " \n", " \n", " War\n", - " 5.3\n", + " 3.9\n", " 5.1\n", - " 4.9\n", + " 5.0\n", " \n", " \n", " Crime\n", - " 4.2\n", + " 3.9\n", " 2.6\n", " 1.0\n", " \n", " \n", " Sci-Fi\n", - " 3.2\n", + " 3.5\n", " 3.4\n", - " 3.9\n", - " \n", - " \n", - " Mystery\n", - " 2.8\n", - " 1.7\n", - " 2.0\n", + " 4.0\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Train Data % BPR % WMF %\n", - "Drama 24.9 17.1 25.5\n", - "Comedy 13.7 13.7 20.6\n", - "Romance 11.2 17.9 17.6\n", - "Action 10.5 13.7 7.8\n", - "Thriller 10.2 12.8 6.9\n", - "Adventure 6.7 6.8 4.9\n", - "War 5.3 5.1 4.9\n", - "Crime 4.2 2.6 1.0\n", - "Sci-Fi 3.2 3.4 3.9\n", - "Mystery 2.8 1.7 2.0" + " Train Data % BPR % WMF %\n", + "Drama 22.6 17.1 26.0\n", + "Comedy 13.9 13.7 21.0\n", + "Romance 10.8 17.9 18.0\n", + "Action 10.6 13.7 7.0\n", + "Thriller 9.7 12.8 6.0\n", + "Adventure 6.9 6.8 5.0\n", + "Children's 4.4 1.7 2.0\n", + "War 3.9 5.1 5.0\n", + "Crime 3.9 2.6 1.0\n", + "Sci-Fi 3.5 3.4 4.0" ] }, "metadata": {}, @@ -1310,7 +1315,7 @@ "id": "c30fe92b", "metadata": {}, "source": [ - "Note that many movies have different genres, so the sum of the genre counts may exceed the total number of recommendations.\n", + "Note that many movies belong to multiple genres, so the sum of the genre counts may exceed the total number of recommendations.\n", "\n", "-------\n", "\n", @@ -1364,7 +1369,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "id": "b349407b", "metadata": {}, "outputs": [ @@ -1517,7 +1522,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "id": "99cdd112", "metadata": {}, "outputs": [ @@ -1579,7 +1584,7 @@ " 4.515\n", " 7\n", " 1644\n", - " 5.058\n", + " 5.057\n", " 11\n", " 1640\n", " 3284\n", @@ -1624,7 +1629,7 @@ "text/plain": [ " ItemID BPR Score BPR Rank BPR Points WMF Score WMF Rank WMF Points \\\n", "152 313 4.494 8 1643 6.066 1 1650 \n", - "194 739 4.515 7 1644 5.058 11 1640 \n", + "194 739 4.515 7 1644 5.057 11 1640 \n", "425 237 4.196 15 1636 4.968 18 1633 \n", "310 692 3.989 26 1625 5.142 9 1642 \n", "382 655 3.979 27 1624 5.165 8 1643 \n", @@ -1665,7 +1670,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 12, "id": "ac86f568", "metadata": {}, "outputs": [ @@ -1708,90 +1713,90 @@ " \n", " \n", " Drama\n", - " 24.9\n", + " 22.6\n", " 17.1\n", - " 25.5\n", + " 26.0\n", " 17.2\n", " \n", " \n", " Comedy\n", + " 13.9\n", " 13.7\n", - " 13.7\n", - " 20.6\n", + " 21.0\n", " 19.0\n", " \n", " \n", " Romance\n", - " 11.2\n", + " 10.8\n", " 17.9\n", - " 17.6\n", + " 18.0\n", " 19.0\n", " \n", " \n", " Action\n", - " 10.5\n", + " 10.6\n", " 13.7\n", - " 7.8\n", + " 7.0\n", " 10.3\n", " \n", " \n", " Thriller\n", - " 10.2\n", + " 9.7\n", " 12.8\n", - " 6.9\n", + " 6.0\n", " 7.8\n", " \n", " \n", " Adventure\n", - " 6.7\n", + " 6.9\n", " 6.8\n", - " 4.9\n", + " 5.0\n", " 6.9\n", " \n", " \n", + " Children's\n", + " 4.4\n", + " 1.7\n", + " 2.0\n", + " 1.7\n", + " \n", + " \n", " War\n", - " 5.3\n", + " 3.9\n", " 5.1\n", - " 4.9\n", + " 5.0\n", " 6.0\n", " \n", " \n", " Crime\n", - " 4.2\n", + " 3.9\n", " 2.6\n", " 1.0\n", " 1.7\n", " \n", " \n", " Sci-Fi\n", - " 3.2\n", + " 3.5\n", " 3.4\n", - " 3.9\n", + " 4.0\n", " 6.9\n", " \n", - " \n", - " Mystery\n", - " 2.8\n", - " 1.7\n", - " 2.0\n", - " 1.7\n", - " \n", " \n", "\n", "" ], "text/plain": [ - " Train Data % BPR % WMF % Borda Count %\n", - "Drama 24.9 17.1 25.5 17.2\n", - "Comedy 13.7 13.7 20.6 19.0\n", - "Romance 11.2 17.9 17.6 19.0\n", - "Action 10.5 13.7 7.8 10.3\n", - "Thriller 10.2 12.8 6.9 7.8\n", - "Adventure 6.7 6.8 4.9 6.9\n", - "War 5.3 5.1 4.9 6.0\n", - "Crime 4.2 2.6 1.0 1.7\n", - "Sci-Fi 3.2 3.4 3.9 6.9\n", - "Mystery 2.8 1.7 2.0 1.7" + " Train Data % BPR % WMF % Borda Count %\n", + "Drama 22.6 17.1 26.0 17.2\n", + "Comedy 13.9 13.7 21.0 19.0\n", + "Romance 10.8 17.9 18.0 19.0\n", + "Action 10.6 13.7 7.0 10.3\n", + "Thriller 9.7 12.8 6.0 7.8\n", + "Adventure 6.9 6.8 5.0 6.9\n", + "Children's 4.4 1.7 2.0 1.7\n", + "War 3.9 5.1 5.0 6.0\n", + "Crime 3.9 2.6 1.0 1.7\n", + "Sci-Fi 3.5 3.4 4.0 6.9" ] }, "metadata": {}, @@ -1843,7 +1848,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 13, "id": "5ce879a6", "metadata": {}, "outputs": [ @@ -1856,11 +1861,18 @@ ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 300/300 [00:22<00:00, 13.57it/s, loss=173]\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "530f726f03454f569ac07d5601426d77", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/300 [00:00\n", " \n", " \n", + " ItemID\n", " WMF Borda Count\n", " \n", " \n", " \n", " \n", " 37\n", - " 14758\n", + " 318\n", + " 14757\n", " \n", " \n", " 152\n", - " 14708\n", + " 313\n", + " 14710\n", " \n", " \n", " 197\n", - " 14660\n", + " 191\n", + " 14663\n", " \n", " \n", " 132\n", + " 272\n", " 14633\n", " \n", " \n", " 156\n", - " 14632\n", + " 64\n", + " 14629\n", " \n", " \n", " 61\n", - " 14603\n", + " 204\n", + " 14605\n", " \n", " \n", " 279\n", - " 14598\n", + " 402\n", + " 14604\n", " \n", " \n", " 305\n", - " 14582\n", + " 181\n", + " 14585\n", " \n", " \n", " 405\n", + " 22\n", " 14581\n", " \n", " \n", " 604\n", - " 14541\n", + " 215\n", + " 14542\n", " \n", " \n", "\n", "" ], "text/plain": [ - " WMF Borda Count\n", - "37 14758\n", - "152 14708\n", - "197 14660\n", - "132 14633\n", - "156 14632\n", - "61 14603\n", - "279 14598\n", - "305 14582\n", - "405 14581\n", - "604 14541" + " ItemID WMF Borda Count\n", + "37 318 14757\n", + "152 313 14710\n", + "197 191 14663\n", + "132 272 14633\n", + "156 64 14629\n", + "61 204 14605\n", + "279 402 14604\n", + "305 181 14585\n", + "405 22 14581\n", + "604 215 14542" ] }, "metadata": {}, @@ -2314,21 +2435,18 @@ "for model in models:\n", " name = model.name\n", " recommendations, scores = model.rank(UIDX)\n", - " rank_2_df[name + \"_rating\"] = scores\n", - " rank_2_df[name + \"_rank\"] = rank_2_df[name + \"_rating\"].rank(ascending=False).astype(int)\n", + " rank_2_df[name + \"_score\"] = scores\n", + " rank_2_df[name + \"_rank\"] = rank_2_df[name + \"_score\"].rank(ascending=False).astype(int)\n", " rank_2_df[name + \"_points\"] = total_items - rank_2_df[name + \"_rank\"]\n", " rank_2_df[\"WMF Borda Count\"] = rank_2_df[\"WMF Borda Count\"] + rank_2_df[name + \"_points\"]\n", "\n", - "# Round results for readability\n", - "rank_2_df = rank_2_df.round(3)\n", - "\n", "# Let's sort and view the top recommendations!\n", - "display(\"Top 10 Recommendations for WMF Borda Count\", rank_2_df[[\"WMF Borda Count\"]].sort_values(\"WMF Borda Count\", ascending=False).head(10))" + "display(\"Top 10 Recommendations for WMF Borda Count\", rank_2_df[[\"ItemID\", \"WMF Borda Count\"]].sort_values(\"WMF Borda Count\", ascending=False).head(10))" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 15, "id": "8224e10e", "metadata": {}, "outputs": [ @@ -2370,80 +2488,80 @@ " \n", " \n", " Drama\n", - " 24.9\n", - " 25.5\n", + " 22.6\n", + " 26.0\n", " 22.7\n", " \n", " \n", " Comedy\n", - " 13.7\n", - " 20.6\n", + " 13.9\n", + " 21.0\n", " 14.3\n", " \n", " \n", " Romance\n", - " 11.2\n", - " 17.6\n", + " 10.8\n", + " 18.0\n", " 16.8\n", " \n", " \n", " Action\n", - " 10.5\n", - " 7.8\n", + " 10.6\n", + " 7.0\n", " 9.2\n", " \n", " \n", " Thriller\n", - " 10.2\n", - " 6.9\n", + " 9.7\n", + " 6.0\n", " 5.9\n", " \n", " \n", " Adventure\n", - " 6.7\n", - " 4.9\n", + " 6.9\n", + " 5.0\n", " 5.9\n", " \n", " \n", + " Children's\n", + " 4.4\n", + " 2.0\n", + " 3.4\n", + " \n", + " \n", " War\n", - " 5.3\n", - " 4.9\n", + " 3.9\n", + " 5.0\n", " 6.7\n", " \n", " \n", " Crime\n", - " 4.2\n", + " 3.9\n", " 1.0\n", " 3.4\n", " \n", " \n", " Sci-Fi\n", - " 3.2\n", - " 3.9\n", + " 3.5\n", + " 4.0\n", " 5.0\n", " \n", - " \n", - " Mystery\n", - " 2.8\n", - " 2.0\n", - " 2.5\n", - " \n", " \n", "\n", "" ], "text/plain": [ - " Train Data % WMF % WMF Borda Count %\n", - "Drama 24.9 25.5 22.7\n", - "Comedy 13.7 20.6 14.3\n", - "Romance 11.2 17.6 16.8\n", - "Action 10.5 7.8 9.2\n", - "Thriller 10.2 6.9 5.9\n", - "Adventure 6.7 4.9 5.9\n", - "War 5.3 4.9 6.7\n", - "Crime 4.2 1.0 3.4\n", - "Sci-Fi 3.2 3.9 5.0\n", - "Mystery 2.8 2.0 2.5" + " Train Data % WMF % WMF Borda Count %\n", + "Drama 22.6 26.0 22.7\n", + "Comedy 13.9 21.0 14.3\n", + "Romance 10.8 18.0 16.8\n", + "Action 10.6 7.0 9.2\n", + "Thriller 9.7 6.0 5.9\n", + "Adventure 6.9 5.0 5.9\n", + "Children's 4.4 2.0 3.4\n", + "War 3.9 5.0 6.7\n", + "Crime 3.9 1.0 3.4\n", + "Sci-Fi 3.5 4.0 5.0" ] }, "metadata": {}, @@ -2452,7 +2570,7 @@ ], "source": [ "# Now, let's add them to the combined dataframe for comparison with earlier models\n", - "wmf_borda_count_topk = list(rank_2_df.sort_values(\"WMF Borda Count\", ascending=False)[\"ItemID\"].values[:TOPK])\n", + "wmf_borda_count_topk = rank_2_df.sort_values(\"WMF Borda Count\", ascending=False)[\"ItemID\"].values[:TOPK]\n", "wmf_borda_df = item_df.loc[[int(i) for i in wmf_borda_count_topk]]\n", "\n", "combined_df[\"WMF Borda Count Sum\"] = wmf_borda_df.select_dtypes(np.number).sum()\n", @@ -2507,7 +2625,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 16, "id": "380223e2", "metadata": {}, "outputs": [ @@ -2515,7 +2633,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 9/9 [02:33<00:00, 17.11s/it]\n" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:13<00:00, 14.81s/it]\n" ] }, { @@ -2562,39 +2680,39 @@ " \n", " \n", " 0\n", - " 2.109998\n", - " 2.071429\n", - " 1.903534\n", - " 2.302128\n", - " 3.117219\n", - " 2.806823\n", - " 3.367875\n", - " 4.242294\n", - " 3.840917\n", + " 2.110029\n", + " 2.071523\n", + " 1.903551\n", + " 2.302063\n", + " 3.117306\n", + " 2.806793\n", + " 3.367846\n", + " 4.248022\n", + " 3.843781\n", " \n", " \n", " 1\n", - " 2.791594\n", - " 2.692412\n", - " 2.421495\n", - " 2.478934\n", - " 2.736730\n", - " 2.779788\n", - " 2.640383\n", - " 2.263831\n", - " 2.303503\n", + " 2.791619\n", + " 2.692337\n", + " 2.421384\n", + " 2.478971\n", + " 2.736729\n", + " 2.779977\n", + " 2.639981\n", + " 2.263646\n", + " 2.319107\n", " \n", " \n", " 2\n", - " 3.750995\n", - " 3.385061\n", - " 3.541965\n", - " 3.757422\n", - " 3.728073\n", - " 4.114565\n", - " 3.430787\n", - " 3.497283\n", - " 3.115823\n", + " 3.751000\n", + " 3.385016\n", + " 3.542090\n", + " 3.759172\n", + " 3.727951\n", + " 4.115041\n", + " 3.442950\n", + " 3.486503\n", + " 3.089492\n", " \n", " \n", "\n", @@ -2602,14 +2720,14 @@ ], "text/plain": [ " WMF_123_score WMF_456_score WMF_789_score WMF_888_score WMF_999_score \\\n", - "0 2.109998 2.071429 1.903534 2.302128 3.117219 \n", - "1 2.791594 2.692412 2.421495 2.478934 2.736730 \n", - "2 3.750995 3.385061 3.541965 3.757422 3.728073 \n", + "0 2.110029 2.071523 1.903551 2.302063 3.117306 \n", + "1 2.791619 2.692337 2.421384 2.478971 2.736729 \n", + "2 3.751000 3.385016 3.542090 3.759172 3.727951 \n", "\n", " WMF_k20_score WMF_k30_score WMF_k40_score WMF_k50_score \n", - "0 2.806823 3.367875 4.242294 3.840917 \n", - "1 2.779788 2.640383 2.263831 2.303503 \n", - "2 4.114565 3.430787 3.497283 3.115823 " + "0 2.806793 3.367846 4.248022 3.843781 \n", + "1 2.779977 2.639981 2.263646 2.319107 \n", + "2 4.115041 3.442950 3.486503 3.089492 " ] }, "metadata": {}, @@ -2630,7 +2748,7 @@ "0 4.0\n", "1 3.0\n", "2 4.0\n", - "Name: ground_score, dtype: float64" + "Name: rating, dtype: float64" ] }, "metadata": {}, @@ -2680,39 +2798,39 @@ " \n", " \n", " 0\n", - " 2.109998\n", - " 2.071429\n", - " 1.903534\n", - " 2.302128\n", - " 3.117219\n", - " 2.806823\n", - " 3.367875\n", - " 4.242294\n", - " 3.840917\n", + " 2.110029\n", + " 2.071523\n", + " 1.903551\n", + " 2.302063\n", + " 3.117306\n", + " 2.806793\n", + " 3.367846\n", + " 4.248022\n", + " 3.843781\n", " \n", " \n", " 1\n", - " 0.807329\n", - " 1.295570\n", - " 0.918373\n", - " 0.553849\n", - " 0.588498\n", - " -0.075745\n", - " -0.366928\n", - " -0.917656\n", - " 0.786125\n", + " 0.807322\n", + " 1.295580\n", + " 0.918263\n", + " 0.553786\n", + " 0.588544\n", + " -0.075428\n", + " -0.366332\n", + " -0.918094\n", + " 0.795499\n", " \n", " \n", " 2\n", - " 1.648487\n", - " 1.456568\n", - " 1.591698\n", - " 1.270859\n", - " 1.677644\n", - " 2.485114\n", - " 2.034434\n", - " 2.333905\n", - " 0.949877\n", + " 1.648435\n", + " 1.456549\n", + " 1.591913\n", + " 1.271597\n", + " 1.677710\n", + " 2.486352\n", + " 2.044948\n", + " 2.325095\n", + " 0.936491\n", " \n", " \n", "\n", @@ -2720,14 +2838,14 @@ ], "text/plain": [ " WMF_123_score WMF_456_score WMF_789_score WMF_888_score WMF_999_score \\\n", - "0 2.109998 2.071429 1.903534 2.302128 3.117219 \n", - "1 0.807329 1.295570 0.918373 0.553849 0.588498 \n", - "2 1.648487 1.456568 1.591698 1.270859 1.677644 \n", + "0 2.110029 2.071523 1.903551 2.302063 3.117306 \n", + "1 0.807322 1.295580 0.918263 0.553786 0.588544 \n", + "2 1.648435 1.456549 1.591913 1.271597 1.677710 \n", "\n", " WMF_k20_score WMF_k30_score WMF_k40_score WMF_k50_score \n", - "0 2.806823 3.367875 4.242294 3.840917 \n", - "1 -0.075745 -0.366928 -0.917656 0.786125 \n", - "2 2.485114 2.034434 2.333905 0.949877 " + "0 2.806793 3.367846 4.248022 3.843781 \n", + "1 -0.075428 -0.366332 -0.918094 0.795499 \n", + "2 2.486352 2.044948 2.325095 0.936491 " ] }, "metadata": {}, @@ -2737,7 +2855,7 @@ "source": [ "# First, lets create training and test data dataframes\n", "training_df = pd.DataFrame(zip(*train_set.uir_tuple)) # Add 'User Index', 'Item Index', 'Rating' triples as records in dataframe\n", - "training_df.columns = ['user_idx', 'item_idx', 'ground_score'] # Set column names\n", + "training_df.columns = ['user_idx', 'item_idx', 'rating'] # Set column names\n", "\n", "# Get all possible user_index, item_index combinations, add them into dataframe for inference\n", "all_df = pd.DataFrame({\n", @@ -2747,6 +2865,7 @@ "all_df['item_id'] = all_df.apply(lambda row: item_idx2id[int(row['item_idx'])], axis=1) # Add 'Item ID' column into dataframe by converting 'Item Index' to 'Item ID'\n", "\n", "# Lets get all the scores for the models trained in Part 3.\n", + "models = [wmf_model_123, wmf_model_456, wmf_model_789, wmf_model_888, wmf_model_999, wmf_model_k20, wmf_model_k30, wmf_model_k40, wmf_model_k50]\n", "\n", "# For each model, we add individual predicted ratings by individual models to training and test dataframes\n", "for model in tqdm(models):\n", @@ -2756,7 +2875,7 @@ "\n", "# Let's pick out the 5 features - predicted ratings from the 5 models trained\n", "X_train = training_df[['WMF_123_score', 'WMF_456_score', 'WMF_789_score', 'WMF_888_score', 'WMF_999_score', 'WMF_k20_score', 'WMF_k30_score', 'WMF_k40_score', 'WMF_k50_score']] # use these predicted ratings as features\n", - "y_train = training_df['ground_score'] # use ground truth to train this linear regression model\n", + "y_train = training_df['rating'] # use ground truth to train this linear regression model\n", "X_inference = all_df[['WMF_123_score', 'WMF_456_score', 'WMF_789_score', 'WMF_888_score', 'WMF_999_score', 'WMF_k20_score', 'WMF_k30_score', 'WMF_k40_score', 'WMF_k50_score']] # all data, used to predict values for ranking\n", "\n", "display(\"Training features\", X_train.head(3)) # predicting ratings as features\n", @@ -2782,7 +2901,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 17, "id": "16a564bd", "metadata": {}, "outputs": [ @@ -2791,8 +2910,8 @@ "output_type": "stream", "text": [ "Coefficients of the linear regression model\n", - "[-0.03534374 0.05045098 -0.09062641 -0.05270565 -0.12603936 -0.1051683\n", - " 0.13939755 0.48237574 0.8496893 ]\n", + "[-0.03553173 0.05021945 -0.09029353 -0.052304 -0.12588827 -0.10598847\n", + " 0.13995151 0.48300126 0.84912807]\n", "0.0\n" ] }, @@ -2835,90 +2954,90 @@ " \n", " \n", " Drama\n", - " 24.9\n", - " 25.5\n", + " 22.6\n", + " 26.0\n", " 22.7\n", " 26.0\n", " \n", " \n", " Comedy\n", - " 13.7\n", - " 20.6\n", + " 13.9\n", + " 21.0\n", " 14.3\n", " 11.5\n", " \n", " \n", " Romance\n", - " 11.2\n", - " 17.6\n", + " 10.8\n", + " 18.0\n", " 16.8\n", " 11.5\n", " \n", " \n", " Action\n", - " 10.5\n", - " 7.8\n", + " 10.6\n", + " 7.0\n", " 9.2\n", " 12.5\n", " \n", " \n", " Thriller\n", - " 10.2\n", - " 6.9\n", + " 9.7\n", + " 6.0\n", " 5.9\n", " 11.5\n", " \n", " \n", " Adventure\n", - " 6.7\n", - " 4.9\n", + " 6.9\n", + " 5.0\n", " 5.9\n", " 8.7\n", " \n", " \n", + " Children's\n", + " 4.4\n", + " 2.0\n", + " 3.4\n", + " 1.9\n", + " \n", + " \n", " War\n", - " 5.3\n", - " 4.9\n", + " 3.9\n", + " 5.0\n", " 6.7\n", " 3.8\n", " \n", " \n", " Crime\n", - " 4.2\n", + " 3.9\n", " 1.0\n", " 3.4\n", " 3.8\n", " \n", " \n", " Sci-Fi\n", - " 3.2\n", - " 3.9\n", + " 3.5\n", + " 4.0\n", " 5.0\n", " 2.9\n", " \n", - " \n", - " Mystery\n", - " 2.8\n", - " 2.0\n", - " 2.5\n", - " 4.8\n", - " \n", " \n", "\n", "" ], "text/plain": [ - " Train Data % WMF % WMF Borda Count % WMF Linear Regression %\n", - "Drama 24.9 25.5 22.7 26.0\n", - "Comedy 13.7 20.6 14.3 11.5\n", - "Romance 11.2 17.6 16.8 11.5\n", - "Action 10.5 7.8 9.2 12.5\n", - "Thriller 10.2 6.9 5.9 11.5\n", - "Adventure 6.7 4.9 5.9 8.7\n", - "War 5.3 4.9 6.7 3.8\n", - "Crime 4.2 1.0 3.4 3.8\n", - "Sci-Fi 3.2 3.9 5.0 2.9\n", - "Mystery 2.8 2.0 2.5 4.8" + " Train Data % WMF % WMF Borda Count % WMF Linear Regression %\n", + "Drama 22.6 26.0 22.7 26.0\n", + "Comedy 13.9 21.0 14.3 11.5\n", + "Romance 10.8 18.0 16.8 11.5\n", + "Action 10.6 7.0 9.2 12.5\n", + "Thriller 9.7 6.0 5.9 11.5\n", + "Adventure 6.9 5.0 5.9 8.7\n", + "Children's 4.4 2.0 3.4 1.9\n", + "War 3.9 5.0 6.7 3.8\n", + "Crime 3.9 1.0 3.4 3.8\n", + "Sci-Fi 3.5 4.0 5.0 2.9" ] }, "metadata": {}, @@ -2939,8 +3058,8 @@ "all_df[\"WMF Linear Regression\"] = y_pred # create a column in `test_df` for the predictions\n", "\n", "# Get Top K ratings from predictions\n", - "all_df = all_df.sort_values(\"WMF Linear Regression\", ascending=False) # sort by predicted ratings\n", - "top_item_ids = all_df[all_df['user_idx'] == UIDX]['item_id'].values[:TOPK] # filter top K (50 as set in Section 2.3)\n", + "sorted_df = all_df.sort_values(\"WMF Linear Regression\", ascending=False) # sort by predicted ratings\n", + "top_item_ids = sorted_df[sorted_df['user_idx'] == UIDX]['item_id'].values[:TOPK] # filter top K (50 as set in Section 2.3)\n", "\n", "# Place them into the comparison distribution dataframe\n", "linear_regression_df = item_df.loc[[int(i) for i in top_item_ids]] # Get genres of ratings\n", @@ -2983,7 +3102,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 18, "id": "0fe095ce", "metadata": {}, "outputs": [ @@ -3027,112 +3146,112 @@ " \n", " \n", " Drama\n", - " 24.9\n", - " 25.5\n", + " 22.6\n", + " 26.0\n", " 22.7\n", " 26.0\n", - " 26.5\n", + " 26.4\n", " \n", " \n", " Comedy\n", - " 13.7\n", - " 20.6\n", + " 13.9\n", + " 21.0\n", " 14.3\n", " 11.5\n", - " 16.9\n", + " 12.3\n", " \n", " \n", " Romance\n", - " 11.2\n", - " 17.6\n", + " 10.8\n", + " 18.0\n", " 16.8\n", " 11.5\n", - " 9.6\n", + " 13.2\n", " \n", " \n", " Action\n", - " 10.5\n", - " 7.8\n", + " 10.6\n", + " 7.0\n", " 9.2\n", " 12.5\n", - " 8.4\n", + " 10.4\n", " \n", " \n", " Thriller\n", - " 10.2\n", - " 6.9\n", + " 9.7\n", + " 6.0\n", " 5.9\n", " 11.5\n", - " 4.8\n", + " 9.4\n", " \n", " \n", " Adventure\n", - " 6.7\n", - " 4.9\n", + " 6.9\n", + " 5.0\n", " 5.9\n", " 8.7\n", - " 2.4\n", + " 7.5\n", + " \n", + " \n", + " Children's\n", + " 4.4\n", + " 2.0\n", + " 3.4\n", + " 1.9\n", + " 0.9\n", " \n", " \n", " War\n", - " 5.3\n", - " 4.9\n", + " 3.9\n", + " 5.0\n", " 6.7\n", " 3.8\n", - " 2.4\n", + " 3.8\n", " \n", " \n", " Crime\n", - " 4.2\n", + " 3.9\n", " 1.0\n", " 3.4\n", " 3.8\n", - " 6.0\n", + " 5.7\n", " \n", " \n", " Sci-Fi\n", - " 3.2\n", - " 3.9\n", + " 3.5\n", + " 4.0\n", " 5.0\n", " 2.9\n", - " 4.8\n", - " \n", - " \n", - " Mystery\n", - " 2.8\n", - " 2.0\n", - " 2.5\n", - " 4.8\n", - " 0.0\n", + " 3.8\n", " \n", " \n", "\n", "" ], "text/plain": [ - " Train Data % WMF % WMF Borda Count % WMF Linear Regression % \\\n", - "Drama 24.9 25.5 22.7 26.0 \n", - "Comedy 13.7 20.6 14.3 11.5 \n", - "Romance 11.2 17.6 16.8 11.5 \n", - "Action 10.5 7.8 9.2 12.5 \n", - "Thriller 10.2 6.9 5.9 11.5 \n", - "Adventure 6.7 4.9 5.9 8.7 \n", - "War 5.3 4.9 6.7 3.8 \n", - "Crime 4.2 1.0 3.4 3.8 \n", - "Sci-Fi 3.2 3.9 5.0 2.9 \n", - "Mystery 2.8 2.0 2.5 4.8 \n", + " Train Data % WMF % WMF Borda Count % WMF Linear Regression % \\\n", + "Drama 22.6 26.0 22.7 26.0 \n", + "Comedy 13.9 21.0 14.3 11.5 \n", + "Romance 10.8 18.0 16.8 11.5 \n", + "Action 10.6 7.0 9.2 12.5 \n", + "Thriller 9.7 6.0 5.9 11.5 \n", + "Adventure 6.9 5.0 5.9 8.7 \n", + "Children's 4.4 2.0 3.4 1.9 \n", + "War 3.9 5.0 6.7 3.8 \n", + "Crime 3.9 1.0 3.4 3.8 \n", + "Sci-Fi 3.5 4.0 5.0 2.9 \n", "\n", - " WMF Random Forest % \n", - "Drama 26.5 \n", - "Comedy 16.9 \n", - "Romance 9.6 \n", - "Action 8.4 \n", - "Thriller 4.8 \n", - "Adventure 2.4 \n", - "War 2.4 \n", - "Crime 6.0 \n", - "Sci-Fi 4.8 \n", - "Mystery 0.0 " + " WMF Random Forest % \n", + "Drama 26.4 \n", + "Comedy 12.3 \n", + "Romance 13.2 \n", + "Action 10.4 \n", + "Thriller 9.4 \n", + "Adventure 7.5 \n", + "Children's 0.9 \n", + "War 3.8 \n", + "Crime 5.7 \n", + "Sci-Fi 3.8 " ] }, "metadata": {}, @@ -3153,8 +3272,8 @@ "all_df[\"WMF Random Forest\"] = y_pred # create a column in `all_df` for the predictions\n", "\n", "# Get Top K ratings from predictions\n", - "all_df = all_df.sort_values(\"WMF Random Forest\", ascending=False) # sort by predicted ratings\n", - "top_item_ids = all_df[all_df['user_idx'] == UIDX]['item_id'].values[:TOPK] # filter top K (50 as set in Section 2.3)\n", + "sorted_df = all_df.sort_values(\"WMF Random Forest\", ascending=False) # sort by predicted ratings\n", + "top_item_ids = sorted_df[sorted_df['user_idx'] == UIDX]['item_id'].values[:TOPK] # filter top K (50 as set in Section 2.3)\n", "\n", "# Place them into the comparison distribution dataframe\n", "random_forest_df = item_df.loc[[int(i) for i in top_item_ids]] # Get genres of ratings\n", @@ -3206,7 +3325,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 19, "id": "553f4f32", "metadata": {}, "outputs": [ @@ -3214,124483 +3333,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/11 [00:00\n", " \n", " \n", - " 152215\n", - " 92\n", - " 323\n", - " 2.531322\n", - " 1251\n", - " 1.493225\n", - " 1129\n", - " 1.493225\n", - " 1129\n", - " 1.655457\n", - " 1534\n", + " 0\n", + " 0\n", + " 0\n", + " 2.265398\n", + " 1494\n", + " 2.110029\n", + " 1495\n", + " 2.110029\n", + " 1495\n", + " 2.071523\n", + " 1488\n", " ...\n", - " 1.408509\n", - " 1350\n", - " 1.241687\n", - " 1271\n", - " 1.256181\n", - " 1212\n", - " 1.702941\n", - " 1452\n", - " -0.229753\n", - " 1034\n", + " 3.117306\n", + " 1629\n", + " 2.806793\n", + " 1608\n", + " 3.367846\n", + " 1626\n", + " 4.248022\n", + " 1646\n", + " 3.843781\n", + " 1637\n", " \n", " \n", - " 1368812\n", - " 829\n", - " 133\n", - " -1.044442\n", - " 1051\n", - " -0.343244\n", - " 1317\n", - " -0.343244\n", - " 1317\n", - " -0.385342\n", - " 991\n", + " 1\n", + " 0\n", + " 1\n", + " 0.368650\n", + " 1003\n", + " 0.807322\n", + " 1026\n", + " 0.807322\n", + " 1026\n", + " 1.295580\n", + " 1259\n", " ...\n", - " -0.086752\n", - " 1276\n", - " 0.168047\n", - " 847\n", - " -0.155295\n", - " 1097\n", - " -0.275467\n", - " 1405\n", - " -0.193489\n", - " 61\n", + " 0.588544\n", + " 1007\n", + " -0.075428\n", + " 330\n", + " -0.366332\n", + " 236\n", + " -0.918094\n", + " 44\n", + " 0.795499\n", + " 1260\n", " \n", " \n", - " 1164837\n", - " 705\n", - " 882\n", - " 0.202551\n", - " 834\n", - " 0.034959\n", - " 872\n", - " 0.034959\n", - " 872\n", - " 0.062831\n", - " 198\n", + " 2\n", + " 0\n", + " 2\n", + " 1.420759\n", + " 1314\n", + " 1.648435\n", + " 1378\n", + " 1.648435\n", + " 1378\n", + " 1.456549\n", + " 1313\n", " ...\n", - " -0.025057\n", - " 82\n", - " 0.117943\n", - " 291\n", - " 0.082732\n", - " 987\n", - " 0.011329\n", - " 801\n", - " 0.084295\n", - " 786\n", + " 1.677710\n", + " 1405\n", + " 2.486352\n", + " 1571\n", + " 2.044948\n", + " 1519\n", + " 2.325095\n", + " 1559\n", + " 0.936491\n", + " 1315\n", " \n", " \n", - " 571527\n", - " 346\n", - " 281\n", - " -1.404196\n", - " 372\n", - " -0.345930\n", - " 622\n", - " -0.345930\n", - " 622\n", - " 0.026527\n", - " 1046\n", + " 3\n", + " 0\n", + " 3\n", + " 0.448797\n", + " 1034\n", + " 0.449700\n", + " 776\n", + " 0.449700\n", + " 776\n", + " -0.102504\n", + " 232\n", " ...\n", - " -0.221407\n", - " 102\n", - " -0.191630\n", - " 363\n", - " 0.069267\n", - " 784\n", - " 0.587795\n", - " 236\n", - " -0.335710\n", - " 1153\n", + " -0.210225\n", + " 151\n", + " -0.114164\n", + " 262\n", + " -0.598675\n", + " 130\n", + " -0.331122\n", + " 243\n", + " 0.106251\n", + " 805\n", " \n", " \n", - " 236458\n", - " 143\n", - " 365\n", - " 1.939757\n", - " 1045\n", - " 2.688915\n", - " 1164\n", - " 2.688915\n", - " 1164\n", - " 2.718009\n", - " 165\n", + " 4\n", + " 0\n", + " 4\n", + " 2.548217\n", + " 1545\n", + " 1.759084\n", + " 1407\n", + " 1.759084\n", + " 1407\n", + " 2.151837\n", + " 1502\n", " ...\n", - " 2.691990\n", - " 716\n", - " 2.771286\n", - " 792\n", - " 2.791590\n", - " 1095\n", - " 2.053599\n", - " 1236\n", - " 1.371249\n", - " 839\n", + " 1.539425\n", + " 1375\n", + " 1.157579\n", + " 1271\n", + " 0.386267\n", + " 993\n", + " 1.672385\n", + " 1454\n", + " 0.668341\n", + " 1197\n", " \n", " \n", "\n", @@ -127864,40 +3507,40 @@ "" ], "text/plain": [ - " user_idx item_idx BPR_score BPR_points WMF_score WMF_points \\\n", - "152215 92 323 2.531322 1251 1.493225 1129 \n", - "1368812 829 133 -1.044442 1051 -0.343244 1317 \n", - "1164837 705 882 0.202551 834 0.034959 872 \n", - "571527 346 281 -1.404196 372 -0.345930 622 \n", - "236458 143 365 1.939757 1045 2.688915 1164 \n", + " user_idx item_idx BPR_score BPR_points WMF_score WMF_points \\\n", + "0 0 0 2.265398 1494 2.110029 1495 \n", + "1 0 1 0.368650 1003 0.807322 1026 \n", + "2 0 2 1.420759 1314 1.648435 1378 \n", + "3 0 3 0.448797 1034 0.449700 776 \n", + "4 0 4 2.548217 1545 1.759084 1407 \n", "\n", - " WMF_123_score WMF_123_points WMF_456_score WMF_456_points ... \\\n", - "152215 1.493225 1129 1.655457 1534 ... \n", - "1368812 -0.343244 1317 -0.385342 991 ... \n", - "1164837 0.034959 872 0.062831 198 ... \n", - "571527 -0.345930 622 0.026527 1046 ... \n", - "236458 2.688915 1164 2.718009 165 ... \n", + " WMF_123_score WMF_123_points WMF_456_score WMF_456_points ... \\\n", + "0 2.110029 1495 2.071523 1488 ... \n", + "1 0.807322 1026 1.295580 1259 ... \n", + "2 1.648435 1378 1.456549 1313 ... \n", + "3 0.449700 776 -0.102504 232 ... \n", + "4 1.759084 1407 2.151837 1502 ... \n", "\n", - " WMF_999_score WMF_999_points WMF_k20_score WMF_k20_points \\\n", - "152215 1.408509 1350 1.241687 1271 \n", - "1368812 -0.086752 1276 0.168047 847 \n", - "1164837 -0.025057 82 0.117943 291 \n", - "571527 -0.221407 102 -0.191630 363 \n", - "236458 2.691990 716 2.771286 792 \n", + " WMF_999_score WMF_999_points WMF_k20_score WMF_k20_points \\\n", + "0 3.117306 1629 2.806793 1608 \n", + "1 0.588544 1007 -0.075428 330 \n", + "2 1.677710 1405 2.486352 1571 \n", + "3 -0.210225 151 -0.114164 262 \n", + "4 1.539425 1375 1.157579 1271 \n", "\n", - " WMF_k30_score WMF_k30_points WMF_k40_score WMF_k40_points \\\n", - "152215 1.256181 1212 1.702941 1452 \n", - "1368812 -0.155295 1097 -0.275467 1405 \n", - "1164837 0.082732 987 0.011329 801 \n", - "571527 0.069267 784 0.587795 236 \n", - "236458 2.791590 1095 2.053599 1236 \n", + " WMF_k30_score WMF_k30_points WMF_k40_score WMF_k40_points \\\n", + "0 3.367846 1626 4.248022 1646 \n", + "1 -0.366332 236 -0.918094 44 \n", + "2 2.044948 1519 2.325095 1559 \n", + "3 -0.598675 130 -0.331122 243 \n", + "4 0.386267 993 1.672385 1454 \n", "\n", - " WMF_k50_score WMF_k50_points \n", - "152215 -0.229753 1034 \n", - "1368812 -0.193489 61 \n", - "1164837 0.084295 786 \n", - "571527 -0.335710 1153 \n", - "236458 1.371249 839 \n", + " WMF_k50_score WMF_k50_points \n", + "0 3.843781 1637 \n", + "1 0.795499 1260 \n", + "2 0.936491 1315 \n", + "3 0.106251 805 \n", + "4 0.668341 1197 \n", "\n", "[5 rows x 24 columns]" ] @@ -127925,10 +3568,10 @@ "\n", " # Calculate points for each user\n", " for user_idx in range(train_set.num_users):\n", - " sub_rank_df = rank_df[rank_df[\"user_idx\"] == user_idx] # get all items for a user\n", - " sub_rank_df[name + \"_rank\"] = sub_rank_df[name + \"_score\"].rank(ascending=False).astype(int) # Get Rank where 1 = Top recommendation\n", - " sub_rank_df[name + \"_points\"] = total_items - sub_rank_df[name + \"_rank\"] # Get points by calculating ('Total Item count' - 'Rank')\n", - " point_list.extend(sub_rank_df[name + \"_points\"].values.tolist()) # add points to list\n", + " sub_rank_df = rank_df[rank_df[\"user_idx\"] == user_idx].copy() # get all items for a user\n", + " sub_rank_df.loc[:, name + \"_rank\"] = sub_rank_df[name + \"_score\"].rank(ascending=False).astype(int) # Rank items\n", + " sub_rank_df.loc[:, name + \"_points\"] = total_items - sub_rank_df[name + \"_rank\"] # Calculate points\n", + " point_list.extend(sub_rank_df[name + \"_points\"].values.tolist()) # Add points to the list\n", " \n", " rank_df[name + \"_points\"] = point_list\n", "\n", @@ -127954,7 +3597,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 20, "id": "240fec5f", "metadata": {}, "outputs": [ @@ -128004,124 +3647,124 @@ " \n", " \n", " \n", - " 152215\n", - " 92\n", - " 323\n", - " 2.531322\n", - " 1251\n", - " 1.493225\n", - " 1129\n", - " 1.493225\n", - " 1129\n", - " 1.655457\n", - " 1534\n", + " 0\n", + " 0\n", + " 0\n", + " 2.265398\n", + " 1494\n", + " 2.110029\n", + " 1495\n", + " 2.110029\n", + " 1495\n", + " 2.071523\n", + " 1488\n", " ...\n", - " 1.241687\n", - " 1271\n", - " 1.256181\n", - " 1212\n", - " 1.702941\n", - " 1452\n", - " -0.229753\n", - " 1034\n", - " 2380\n", - " 11989\n", + " 2.806793\n", + " 1608\n", + " 3.367846\n", + " 1626\n", + " 4.248022\n", + " 1646\n", + " 3.843781\n", + " 1637\n", + " 2989\n", + " 14111\n", " \n", " \n", - " 1368812\n", - " 829\n", - " 133\n", - " -1.044442\n", - " 1051\n", - " -0.343244\n", - " 1317\n", - " -0.343244\n", - " 1317\n", - " -0.385342\n", - " 991\n", + " 1\n", + " 0\n", + " 1\n", + " 0.368650\n", + " 1003\n", + " 0.807322\n", + " 1026\n", + " 0.807322\n", + " 1026\n", + " 1.295580\n", + " 1259\n", " ...\n", - " 0.168047\n", - " 847\n", - " -0.155295\n", - " 1097\n", - " -0.275467\n", - " 1405\n", - " -0.193489\n", - " 61\n", - " 2368\n", - " 9449\n", + " -0.075428\n", + " 330\n", + " -0.366332\n", + " 236\n", + " -0.918094\n", + " 44\n", + " 0.795499\n", + " 1260\n", + " 2029\n", + " 7150\n", " \n", " \n", - " 1164837\n", - " 705\n", - " 882\n", - " 0.202551\n", - " 834\n", - " 0.034959\n", - " 872\n", - " 0.034959\n", - " 872\n", - " 0.062831\n", - " 198\n", + " 2\n", + " 0\n", + " 2\n", + " 1.420759\n", + " 1314\n", + " 1.648435\n", + " 1378\n", + " 1.648435\n", + " 1378\n", + " 1.456549\n", + " 1313\n", " ...\n", - " 0.117943\n", - " 291\n", - " 0.082732\n", - " 987\n", - " 0.011329\n", - " 801\n", - " 0.084295\n", - " 786\n", - " 1706\n", - " 4288\n", + " 2.486352\n", + " 1571\n", + " 2.044948\n", + " 1519\n", + " 2.325095\n", + " 1559\n", + " 0.936491\n", + " 1315\n", + " 2692\n", + " 12639\n", " \n", " \n", - " 571527\n", - " 346\n", - " 281\n", - " -1.404196\n", - " 372\n", - " -0.345930\n", - " 622\n", - " -0.345930\n", - " 622\n", - " 0.026527\n", - " 1046\n", + " 3\n", + " 0\n", + " 3\n", + " 0.448797\n", + " 1034\n", + " 0.449700\n", + " 776\n", + " 0.449700\n", + " 776\n", + " -0.102504\n", + " 232\n", " ...\n", - " -0.191630\n", - " 363\n", - " 0.069267\n", - " 784\n", - " 0.587795\n", - " 236\n", - " -0.335710\n", - " 1153\n", - " 994\n", - " 4522\n", + " -0.114164\n", + " 262\n", + " -0.598675\n", + " 130\n", + " -0.331122\n", + " 243\n", + " 0.106251\n", + " 805\n", + " 1810\n", + " 3917\n", " \n", " \n", - " 236458\n", - " 143\n", - " 365\n", - " 1.939757\n", - " 1045\n", - " 2.688915\n", - " 1164\n", - " 2.688915\n", - " 1164\n", - " 2.718009\n", - " 165\n", + " 4\n", + " 0\n", + " 4\n", + " 2.548217\n", + " 1545\n", + " 1.759084\n", + " 1407\n", + " 1.759084\n", + " 1407\n", + " 2.151837\n", + " 1502\n", " ...\n", - " 2.771286\n", - " 792\n", - " 2.791590\n", - " 1095\n", - " 2.053599\n", - " 1236\n", - " 1.371249\n", - " 839\n", - " 2209\n", - " 7257\n", + " 1.157579\n", + " 1271\n", + " 0.386267\n", + " 993\n", + " 1.672385\n", + " 1454\n", + " 0.668341\n", + " 1197\n", + " 2952\n", + " 12121\n", " \n", " \n", "\n", @@ -128129,40 +3772,40 @@ "" ], "text/plain": [ - " user_idx item_idx BPR_score BPR_points WMF_score WMF_points \\\n", - "152215 92 323 2.531322 1251 1.493225 1129 \n", - "1368812 829 133 -1.044442 1051 -0.343244 1317 \n", - "1164837 705 882 0.202551 834 0.034959 872 \n", - "571527 346 281 -1.404196 372 -0.345930 622 \n", - "236458 143 365 1.939757 1045 2.688915 1164 \n", + " user_idx item_idx BPR_score BPR_points WMF_score WMF_points \\\n", + "0 0 0 2.265398 1494 2.110029 1495 \n", + "1 0 1 0.368650 1003 0.807322 1026 \n", + "2 0 2 1.420759 1314 1.648435 1378 \n", + "3 0 3 0.448797 1034 0.449700 776 \n", + "4 0 4 2.548217 1545 1.759084 1407 \n", "\n", - " WMF_123_score WMF_123_points WMF_456_score WMF_456_points ... \\\n", - "152215 1.493225 1129 1.655457 1534 ... \n", - "1368812 -0.343244 1317 -0.385342 991 ... \n", - "1164837 0.034959 872 0.062831 198 ... \n", - "571527 -0.345930 622 0.026527 1046 ... \n", - "236458 2.688915 1164 2.718009 165 ... \n", + " WMF_123_score WMF_123_points WMF_456_score WMF_456_points ... \\\n", + "0 2.110029 1495 2.071523 1488 ... \n", + "1 0.807322 1026 1.295580 1259 ... \n", + "2 1.648435 1378 1.456549 1313 ... \n", + "3 0.449700 776 -0.102504 232 ... \n", + "4 1.759084 1407 2.151837 1502 ... \n", "\n", - " WMF_k20_score WMF_k20_points WMF_k30_score WMF_k30_points \\\n", - "152215 1.241687 1271 1.256181 1212 \n", - "1368812 0.168047 847 -0.155295 1097 \n", - "1164837 0.117943 291 0.082732 987 \n", - "571527 -0.191630 363 0.069267 784 \n", - "236458 2.771286 792 2.791590 1095 \n", + " WMF_k20_score WMF_k20_points WMF_k30_score WMF_k30_points \\\n", + "0 2.806793 1608 3.367846 1626 \n", + "1 -0.075428 330 -0.366332 236 \n", + "2 2.486352 1571 2.044948 1519 \n", + "3 -0.114164 262 -0.598675 130 \n", + "4 1.157579 1271 0.386267 993 \n", "\n", - " WMF_k40_score WMF_k40_points WMF_k50_score WMF_k50_points \\\n", - "152215 1.702941 1452 -0.229753 1034 \n", - "1368812 -0.275467 1405 -0.193489 61 \n", - "1164837 0.011329 801 0.084295 786 \n", - "571527 0.587795 236 -0.335710 1153 \n", - "236458 2.053599 1236 1.371249 839 \n", + " WMF_k40_score WMF_k40_points WMF_k50_score WMF_k50_points Borda Count \\\n", + "0 4.248022 1646 3.843781 1637 2989 \n", + "1 -0.918094 44 0.795499 1260 2029 \n", + "2 2.325095 1559 0.936491 1315 2692 \n", + "3 -0.331122 243 0.106251 805 1810 \n", + "4 1.672385 1454 0.668341 1197 2952 \n", "\n", - " Borda Count WMF Borda Count \n", - "152215 2380 11989 \n", - "1368812 2368 9449 \n", - "1164837 1706 4288 \n", - "571527 994 4522 \n", - "236458 2209 7257 \n", + " WMF Borda Count \n", + "0 14111 \n", + "1 7150 \n", + "2 12639 \n", + "3 3917 \n", + "4 12121 \n", "\n", "[5 rows x 26 columns]" ] @@ -128226,105 +3869,105 @@ " 0\n", " 0\n", " 381\n", - " 2.109998\n", - " 2.071429\n", - " 1.903534\n", - " 2.302128\n", - " 3.117219\n", - " 2.806823\n", - " 3.367875\n", - " 4.242294\n", - " 3.840917\n", - " 4.827443\n", - " 1.30\n", - " 2.531322\n", - " 1.493225\n", - " 2380\n", - " 11989\n", + " 2.110029\n", + " 2.071523\n", + " 1.903551\n", + " 2.302063\n", + " 3.117306\n", + " 2.806793\n", + " 3.367846\n", + " 4.248022\n", + " 3.843781\n", + " 4.833850\n", + " 4.34\n", + " 2.265398\n", + " 2.110029\n", + " 2989\n", + " 14111\n", " \n", " \n", " 1\n", " 0\n", " 1\n", " 602\n", - " 0.807329\n", - " 1.295570\n", - " 0.918373\n", - " 0.553849\n", - " 0.588498\n", - " -0.075745\n", - " -0.366928\n", - " -0.917656\n", - " 0.786125\n", - " 0.032359\n", - " 1.02\n", - " -1.044442\n", - " -0.343244\n", - " 2368\n", - " 9449\n", + " 0.807322\n", + " 1.295580\n", + " 0.918263\n", + " 0.553786\n", + " 0.588544\n", + " -0.075428\n", + " -0.366332\n", + " -0.918094\n", + " 0.795499\n", + " 0.039174\n", + " 1.48\n", + " 0.368650\n", + " 0.807322\n", + " 2029\n", + " 7150\n", " \n", " \n", " 2\n", " 0\n", " 2\n", " 431\n", - " 1.648487\n", - " 1.456568\n", - " 1.591698\n", - " 1.270859\n", - " 1.677644\n", - " 2.485114\n", - " 2.034434\n", - " 2.333905\n", - " 0.949877\n", - " 1.547700\n", - " 1.92\n", - " 0.202551\n", - " 0.034959\n", - " 1706\n", - " 4288\n", + " 1.648435\n", + " 1.456549\n", + " 1.591913\n", + " 1.271597\n", + " 1.677710\n", + " 2.486352\n", + " 2.044948\n", + " 2.325095\n", + " 0.936491\n", + " 1.534016\n", + " 2.10\n", + " 1.420759\n", + " 1.648435\n", + " 2692\n", + " 12639\n", " \n", " \n", " 3\n", " 0\n", " 3\n", " 875\n", - " 0.449641\n", - " -0.102344\n", - " 0.439881\n", - " 0.164935\n", - " -0.209975\n", - " -0.113291\n", - " -0.599113\n", - " -0.331742\n", - " 0.105063\n", - " -0.185502\n", - " 1.04\n", - " -1.404196\n", - " -0.345930\n", - " 994\n", - " 4522\n", + " 0.449700\n", + " -0.102504\n", + " 0.439818\n", + " 0.164991\n", + " -0.210225\n", + " -0.114164\n", + " -0.598675\n", + " -0.331122\n", + " 0.106251\n", + " -0.184401\n", + " 1.08\n", + " 0.448797\n", + " 0.449700\n", + " 1810\n", + " 3917\n", " \n", " \n", " 4\n", " 0\n", " 4\n", " 182\n", - " 1.758960\n", - " 2.151483\n", - " 2.659284\n", - " 1.573445\n", - " 1.539223\n", - " 1.145167\n", - " 0.355686\n", - " 1.674228\n", - " 0.677513\n", - " 0.840872\n", - " 1.00\n", - " 1.939757\n", - " 2.688915\n", - " 2209\n", - " 7257\n", + " 1.759084\n", + " 2.151837\n", + " 2.659510\n", + " 1.573485\n", + " 1.539425\n", + " 1.157579\n", + " 0.386267\n", + " 1.672385\n", + " 0.668341\n", + " 0.835969\n", + " 1.50\n", + " 2.548217\n", + " 1.759084\n", + " 2952\n", + " 12121\n", " \n", " \n", "\n", @@ -128332,32 +3975,32 @@ ], "text/plain": [ " user_idx item_idx item_id WMF_123_score WMF_456_score WMF_789_score \\\n", - "0 0 0 381 2.109998 2.071429 1.903534 \n", - "1 0 1 602 0.807329 1.295570 0.918373 \n", - "2 0 2 431 1.648487 1.456568 1.591698 \n", - "3 0 3 875 0.449641 -0.102344 0.439881 \n", - "4 0 4 182 1.758960 2.151483 2.659284 \n", + "0 0 0 381 2.110029 2.071523 1.903551 \n", + "1 0 1 602 0.807322 1.295580 0.918263 \n", + "2 0 2 431 1.648435 1.456549 1.591913 \n", + "3 0 3 875 0.449700 -0.102504 0.439818 \n", + "4 0 4 182 1.759084 2.151837 2.659510 \n", "\n", " WMF_888_score WMF_999_score WMF_k20_score WMF_k30_score WMF_k40_score \\\n", - "0 2.302128 3.117219 2.806823 3.367875 4.242294 \n", - "1 0.553849 0.588498 -0.075745 -0.366928 -0.917656 \n", - "2 1.270859 1.677644 2.485114 2.034434 2.333905 \n", - "3 0.164935 -0.209975 -0.113291 -0.599113 -0.331742 \n", - "4 1.573445 1.539223 1.145167 0.355686 1.674228 \n", + "0 2.302063 3.117306 2.806793 3.367846 4.248022 \n", + "1 0.553786 0.588544 -0.075428 -0.366332 -0.918094 \n", + "2 1.271597 1.677710 2.486352 2.044948 2.325095 \n", + "3 0.164991 -0.210225 -0.114164 -0.598675 -0.331122 \n", + "4 1.573485 1.539425 1.157579 0.386267 1.672385 \n", "\n", " WMF_k50_score WMF Linear Regression WMF Random Forest BPR_score \\\n", - "0 3.840917 4.827443 1.30 2.531322 \n", - "1 0.786125 0.032359 1.02 -1.044442 \n", - "2 0.949877 1.547700 1.92 0.202551 \n", - "3 0.105063 -0.185502 1.04 -1.404196 \n", - "4 0.677513 0.840872 1.00 1.939757 \n", + "0 3.843781 4.833850 4.34 2.265398 \n", + "1 0.795499 0.039174 1.48 0.368650 \n", + "2 0.936491 1.534016 2.10 1.420759 \n", + "3 0.106251 -0.184401 1.08 0.448797 \n", + "4 0.668341 0.835969 1.50 2.548217 \n", "\n", " WMF_score Borda Count WMF Borda Count \n", - "0 1.493225 2380 11989 \n", - "1 -0.343244 2368 9449 \n", - "2 0.034959 1706 4288 \n", - "3 -0.345930 994 4522 \n", - "4 2.688915 2209 7257 " + "0 2.110029 2989 14111 \n", + "1 0.807322 2029 7150 \n", + "2 1.648435 2692 12639 \n", + "3 0.449700 1810 3917 \n", + "4 1.759084 2952 12121 " ] }, "metadata": {}, @@ -128365,7 +4008,6 @@ } ], "source": [ - "\n", "borda_count_models = [bpr_model, wmf_model]\n", "rank_df[\"Borda Count\"] = rank_df[[model.name + \"_points\" for model in borda_count_models]].sum(axis=1) # Sum up points of BPR and WMF\n", "\n", @@ -128398,7 +4040,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 21, "id": "916b390b", "metadata": {}, "outputs": [ @@ -128406,7 +4048,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 6/6 [00:29<00:00, 4.93s/it]\n" + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:15<00:00, 2.63s/it]\n" ] }, { @@ -128448,14 +4090,14 @@ " \n", " 0\n", " Precision@50\n", - " 0.012979\n", - " 0.013511\n", + " 0.099574\n", + " 0.103213\n", " \n", " \n", " 1\n", " Recall@50\n", - " 0.030273\n", - " 0.030955\n", + " 0.363850\n", + " 0.372287\n", " \n", " \n", "\n", @@ -128463,8 +4105,8 @@ ], "text/plain": [ " Metrics BPR_score WMF_score\n", - "0 Precision@50 0.012979 0.013511\n", - "1 Recall@50 0.030273 0.030955" + "0 Precision@50 0.099574 0.103213\n", + "1 Recall@50 0.363850 0.372287" ] }, "metadata": {}, @@ -128510,16 +4152,16 @@ " \n", " 0\n", " Precision@50\n", - " 0.015085\n", - " 0.075894\n", - " 0.013149\n", + " 0.099745\n", + " 0.075809\n", + " 0.069128\n", " \n", " \n", " 1\n", " Recall@50\n", - " 0.034803\n", - " 0.312299\n", - " 0.029173\n", + " 0.379883\n", + " 0.312031\n", + " 0.283792\n", " \n", " \n", "\n", @@ -128527,8 +4169,8 @@ ], "text/plain": [ " Metrics WMF Borda Count WMF Linear Regression WMF Random Forest\n", - "0 Precision@50 0.015085 0.075894 0.013149\n", - "1 Recall@50 0.034803 0.312299 0.029173" + "0 Precision@50 0.099745 0.075809 0.069128\n", + "1 Recall@50 0.379883 0.312031 0.283792" ] }, "metadata": {}, @@ -128544,13 +4186,12 @@ "\n", "test_users = set(test_set.uir_tuple[0])\n", "for model in tqdm(models):\n", - " all_df = all_df.sort_values(model, ascending=False) # sort by predicted ratings\n", - " predicted_ids = [all_df[all_df['user_idx'] == uidx]['item_idx'].values[:TOPK].astype(int) for uidx in range(train_set.num_users)]\n", + " sorted_df = all_df.sort_values(model, ascending=False) # sort by predicted ratings\n", " precisions, recalls = [], []\n", " \n", " for uidx in test_users:\n", " true_top_k = test_set.user_data[uidx][0] # ground truth data\n", - " predicted_top_k = predicted_ids[uidx].tolist() # predicted ranking data\n", + " predicted_top_k = sorted_df[sorted_df['user_idx'] == uidx]['item_idx'].values[:TOPK].astype(int)\n", " # Precision@K\n", " precision = len(set(true_top_k) & set(predicted_top_k)) / len(predicted_top_k)\n", " precisions.append(precision)\n", @@ -128600,7 +4241,7 @@ ], "metadata": { "kernelspec": { - "display_name": "gcmc-pytorch", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -128614,7 +4255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.10.14" } }, "nbformat": 4,