Skip to content

Commit

Permalink
setting gpu to auto
Browse files · Browse the repository at this point in the history
  • Loading branch information
berylliumsec-handler committed Apr 26, 2024
1 parent 04f135b commit bdba813
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 27 deletions.
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@
"langchain",
"regex",
"argparse",
"typing-extensions"


"typing-extensions",
],
entry_points={
"console_scripts": [
Expand Down
31 changes: 8 additions & 23 deletions src/neutron/interactive_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class InteractiveModel:
def __init__(self):
# Device configuration

utilities.check_new_pypi_version()
utilities.ensure_model_folder_exists("neutron_model")
utilities.ensure_model_folder_exists("neutron_chroma.db")
Expand All @@ -37,27 +37,13 @@ def __init__(self):
model_max_length=8192,
low_cpu_mem_usage=True,
)
total_memory_gb = 0
if torch.cuda.is_available():
total_memory_gb = torch.cuda.get_device_properties(0).total_memory / (
1024**3
) # Convert bytes to GB
print(f"total GPU memory available {total_memory_gb}")
if total_memory_gb < 24:
print("There isnt enough GPU memory, will use CPU")

if total_memory_gb >= 24:
self.model = AutoModelForCausalLM.from_pretrained(
utilities.return_path("neutron_model"),
quantization_config=bnb_config,
device_map={"": 0},
)
else:
self.model = AutoModelForCausalLM.from_pretrained(
utilities.return_path("neutron_model"),
low_cpu_mem_usage=True,
quantization_config=bnb_config,
)
self.model = AutoModelForCausalLM.from_pretrained(
utilities.return_path("neutron_model"),
low_cpu_mem_usage=True,
quantization_config=bnb_config,
device_map="auto",
)
# Pipeline configuration
self.pipe = pipeline(
"text-generation",
Expand All @@ -79,7 +65,6 @@ def __init__(self):
persist_directory="neutron_chroma.db",
)
self.retriever = self.db.as_retriever(search_type="mmr")
# self.retriever = self.db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 1.0})

self.template = ChatPromptTemplate.from_template(
"""
Expand Down Expand Up @@ -111,7 +96,7 @@ def __init__(self):

def invoke(self, question: str):
self.search_results = self.search.run(question)
# print(self.search_results)

return self.chain.invoke(question)

def search_duck(self, question: str):
Expand Down
3 changes: 2 additions & 1 deletion src/neutron/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import argparse
import os
from typing import Any, Dict
import torch

import psutil
import torch
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
Expand Down

0 comments on commit bdba813

Please sign in to comment.