From bdba8132cfea77b1d6a2d767b1826deb61fc6ec4 Mon Sep 17 00:00:00 2001 From: David Date: Fri, 26 Apr 2024 18:44:42 -0400 Subject: [PATCH] setting gpu to auto --- setup.py | 4 +--- src/neutron/interactive_model.py | 31 ++++++++----------------------- src/neutron/server.py | 3 ++- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 34bd5f2..d67c1ab 100644 --- a/setup.py +++ b/setup.py @@ -36,9 +36,7 @@ "langchain", "regex", "argparse", - "typing-extensions" - - + "typing-extensions", ], entry_points={ "console_scripts": [ diff --git a/src/neutron/interactive_model.py b/src/neutron/interactive_model.py index b71707b..91ce6a5 100644 --- a/src/neutron/interactive_model.py +++ b/src/neutron/interactive_model.py @@ -21,7 +21,7 @@ class InteractiveModel: def __init__(self): # Device configuration - + utilities.check_new_pypi_version() utilities.ensure_model_folder_exists("neutron_model") utilities.ensure_model_folder_exists("neutron_chroma.db") @@ -37,27 +37,13 @@ def __init__(self): model_max_length=8192, low_cpu_mem_usage=True, ) - total_memory_gb = 0 - if torch.cuda.is_available(): - total_memory_gb = torch.cuda.get_device_properties(0).total_memory / ( - 1024**3 - ) # Convert bytes to GB - print(f"total GPU memory available {total_memory_gb}") - if total_memory_gb < 24: - print("There isnt enough GPU memory, will use CPU") - if total_memory_gb >= 24: - self.model = AutoModelForCausalLM.from_pretrained( - utilities.return_path("neutron_model"), - quantization_config=bnb_config, - device_map={"": 0}, - ) - else: - self.model = AutoModelForCausalLM.from_pretrained( - utilities.return_path("neutron_model"), - low_cpu_mem_usage=True, - quantization_config=bnb_config, - ) + self.model = AutoModelForCausalLM.from_pretrained( + utilities.return_path("neutron_model"), + low_cpu_mem_usage=True, + quantization_config=bnb_config, + device_map="auto", + ) # Pipeline configuration self.pipe = pipeline( "text-generation", @@ -79,7 +65,6 @@ def __init__(self): persist_directory="neutron_chroma.db", ) self.retriever = self.db.as_retriever(search_type="mmr") - # self.retriever = self.db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 1.0}) self.template = ChatPromptTemplate.from_template( """ @@ -111,7 +96,7 @@ def __init__(self): def invoke(self, question: str): self.search_results = self.search.run(question) - # print(self.search_results) + return self.chain.invoke(question) def search_duck(self, question: str): diff --git a/src/neutron/server.py b/src/neutron/server.py index 001fc94..3140e49 100644 --- a/src/neutron/server.py +++ b/src/neutron/server.py @@ -1,8 +1,9 @@ import argparse import os from typing import Any, Dict -import torch + import psutil +import torch from fastapi import FastAPI, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel