-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into feature/locust
- Loading branch information
Showing
14 changed files
with
412 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# | ||
name: Create and publish a Docker image | ||
|
||
# Configures this workflow to run every time a change is pushed to the branch called `release`. | ||
on: | ||
push: | ||
branches: | ||
- master | ||
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. | ||
env: | ||
REGISTRY: ghcr.io | ||
IMAGE_NAME: ${{ github.repository }} | ||
|
||
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. | ||
jobs: | ||
build-and-push-image: | ||
runs-on: ubuntu-latest | ||
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. | ||
permissions: | ||
contents: read | ||
packages: write | ||
# | ||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v4 | ||
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. | ||
- name: Log in to the Container registry | ||
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 | ||
with: | ||
registry: ${{ env.REGISTRY }} | ||
username: ${{ github.actor }} | ||
password: ${{ secrets.GITHUB_TOKEN }} | ||
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. | ||
- name: Extract metadata (tags, labels) for Docker | ||
id: meta | ||
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 | ||
with: | ||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} | ||
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. | ||
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository. | ||
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. | ||
- name: Build and push Docker image | ||
uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 | ||
with: | ||
context: API/ | ||
push: true | ||
tags: ${{ steps.meta.outputs.tags }} | ||
labels: ${{ steps.meta.outputs.labels }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
name: Docker Image CI | ||
|
||
on: | ||
push: | ||
branches: [ "master" ] | ||
pull_request: | ||
branches: [ "master" ] | ||
|
||
jobs: | ||
|
||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Build the Docker image | ||
run: docker build ./API/. --file API/Dockerfile --tag my-image-name:$(date +%s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
ARG PYTHON_VERSION=3.10 | ||
|
||
FROM docker.io/python:${PYTHON_VERSION}-slim-bookworm | ||
|
||
RUN apt-get update \ | ||
&& apt-get -y upgrade \ | ||
&& apt-get --no-install-recommends -y install \ | ||
build-essential libgdal-dev libboost-numpy-dev | ||
|
||
COPY requirements.txt requirements.txt | ||
|
||
RUN \ | ||
python3 -m pip install --upgrade pip \ | ||
&& python3 -m pip install -r requirements.txt | ||
|
||
WORKDIR /app | ||
|
||
COPY main.py /app/main.py | ||
|
||
EXPOSE 8000 | ||
|
||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
## FastAPI Prediction API | ||
|
||
Contains a FastAPI-based API for making predictions using a fAIr model. It provides an endpoint to predict results based on specified parameters. | ||
|
||
### Prerequisites | ||
|
||
- Docker installed on your system | ||
|
||
### Getting Started | ||
|
||
1. Clone Repo and Navigate to /API | ||
|
||
```bash | ||
git clone https://github.com/kshitijrajsharma/fairpredictor.git | ||
cd API | ||
``` | ||
|
||
2. Build Docker Image | ||
|
||
```bash | ||
docker build -t predictor-api . | ||
``` | ||
|
||
3. Run Docker Container | ||
|
||
```bash | ||
docker run -p 8080:8000 predictor-api | ||
``` | ||
|
||
4. API Documentation | ||
|
||
- Redocly Documentation - > Go to your_API_url/redoc : for eg [localhost:redoc](http://localhost:8080/redoc) | ||
- Swagger Documentation - > Go to your_API_url/docs : for eg [localhost:docs](http://localhost:8080/docs#/default/predict_api_predict__post) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
import os | ||
import tempfile | ||
from typing import List, Optional | ||
|
||
import requests | ||
from fastapi import FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from pydantic import BaseModel, Field, PositiveFloat, validator | ||
|
||
from predictor import predict | ||
|
||
app = FastAPI( | ||
title="fAIr Prediction API", | ||
description="Standalone API for Running .h5, .tf, .tflite Model Predictions", | ||
) | ||
|
||
|
||
origins = ["*"] | ||
|
||
app.add_middleware( | ||
CORSMiddleware, | ||
allow_origins=origins, | ||
allow_credentials=True, | ||
allow_methods=["*"], | ||
allow_headers=["*"], | ||
) | ||
|
||
|
||
class PredictionRequest(BaseModel): | ||
""" | ||
Request model for the prediction endpoint. | ||
Example : | ||
{ | ||
"bbox": [ | ||
100.56228021333352, | ||
13.685230854641182, | ||
100.56383321235313, | ||
13.685961853747969 | ||
], | ||
"checkpoint": "https://fair-dev.hotosm.org/api/v1/workspace/download/dataset_58/output/training_324//checkpoint.tflite", | ||
"zoom_level": 20, | ||
"source": "https://tiles.openaerialmap.org/6501a65c0906de000167e64d/0/6501a65c0906de000167e64e/{z}/{x}/{y}" | ||
} | ||
""" | ||
|
||
bbox: List[float] | ||
|
||
checkpoint: str = Field( | ||
..., | ||
example="path/to/model.tflite or https://example.com/model.tflite", | ||
description="Path or URL to the machine learning model file.", | ||
) | ||
|
||
zoom_level: int = Field( | ||
..., | ||
description="Zoom level of the tiles to be used for prediction.", | ||
) | ||
|
||
source: str = Field( | ||
..., | ||
description="Your Image URL on which you want to detect features.", | ||
) | ||
|
||
use_josm_q: Optional[bool] = Field( | ||
False, | ||
description="Indicates whether to use JOSM query. Defaults to False.", | ||
) | ||
|
||
merge_adjacent_polygons: Optional[bool] = Field( | ||
True, | ||
description="Merges adjacent self-intersecting or containing each other polygons. Defaults to True.", | ||
) | ||
|
||
confidence: Optional[int] = Field( | ||
50, | ||
description="Threshold probability for filtering out low-confidence predictions. Defaults to 50.", | ||
) | ||
|
||
max_angle_change: Optional[int] = Field( | ||
15, | ||
description="Maximum angle change parameter for prediction. Defaults to 15.", | ||
) | ||
|
||
skew_tolerance: Optional[int] = Field( | ||
15, | ||
description="Skew tolerance parameter for prediction. Defaults to 15.", | ||
) | ||
|
||
tolerance: Optional[float] = Field( | ||
0.5, | ||
description="Tolerance parameter for simplifying polygons. Defaults to 0.5.", | ||
) | ||
|
||
area_threshold: Optional[float] = Field( | ||
3, | ||
description="Threshold for filtering polygon areas. Defaults to 3.", | ||
) | ||
|
||
tile_overlap_distance: Optional[float] = Field( | ||
0.15, | ||
description="Provides tile overlap distance to remove the strip between predictions. Defaults to 0.15.", | ||
) | ||
|
||
@validator( | ||
"max_angle_change", | ||
"skew_tolerance", | ||
) | ||
def validate_values(cls, value): | ||
if value is not None: | ||
if value < 0 or value > 45: | ||
raise ValueError(f"Value should be between 0 and 45: {value}") | ||
return value | ||
|
||
@validator("tolerance") | ||
def validate_tolerance(cls, value): | ||
if value is not None: | ||
if value < 0 or value > 10: | ||
raise ValueError(f"Value should be between 0 and 10: {value}") | ||
return value | ||
|
||
@validator("tile_overlap_distance") | ||
def validate_tile_overlap_distance(cls, value): | ||
if value is not None: | ||
if value < 0 or value > 1: | ||
raise ValueError(f"Value should be between 0 and 1: {value}") | ||
return value | ||
|
||
@validator("area_threshold") | ||
def validate_area_threshold(cls, value): | ||
if value is not None: | ||
if value < 0 or value > 20: | ||
raise ValueError(f"Value should be between 0 and 20: {value}") | ||
return value | ||
|
||
@validator("confidence") | ||
def validate_confidence(cls, value): | ||
if value is not None: | ||
if value < 0 or value > 100: | ||
raise ValueError(f"Value should be between 0 and 100: {value}") | ||
return value / 100 | ||
|
||
@validator("bbox") | ||
def validate_bbox(cls, value): | ||
if len(value) != 4: | ||
raise ValueError("bbox should have exactly 4 elements") | ||
return value | ||
|
||
@validator("zoom_level") | ||
def validate_zoom_level(cls, value): | ||
if value < 18 or value > 22: | ||
raise ValueError("Zoom level should be between 18 and 22") | ||
return value | ||
|
||
@validator("checkpoint") | ||
def validate_checkpoint(cls, value): | ||
""" | ||
Validates checkpoint parameter. If URL, download the file to temp directory. | ||
""" | ||
if value.startswith("http"): | ||
response = requests.get(value) | ||
if response.status_code != 200: | ||
raise ValueError( | ||
"Failed to download model checkpoint from the provided URL" | ||
) | ||
_, temp_file_path = tempfile.mkstemp(suffix=".tflite") | ||
with open(temp_file_path, "wb") as f: | ||
f.write(response.content) | ||
return temp_file_path | ||
elif not os.path.exists(value): | ||
raise ValueError("Model checkpoint file not found") | ||
return value | ||
|
||
|
||
@app.post("/predict/") | ||
async def predict_api(request: PredictionRequest): | ||
""" | ||
Endpoint to predict results based on specified parameters. | ||
Parameters: | ||
- `request` (PredictionRequest): Request body containing prediction parameters. | ||
Returns: | ||
- Predicted results. | ||
""" | ||
try: | ||
predictions = predict( | ||
bbox=request.bbox, | ||
model_path=request.checkpoint, | ||
zoom_level=request.zoom_level, | ||
tms_url=request.source, | ||
tile_size=256, | ||
confidence=request.confidence, | ||
tile_overlap_distance=request.tile_overlap_distance, | ||
merge_adjancent_polygons=request.merge_adjacent_polygons, | ||
max_angle_change=request.max_angle_change, | ||
skew_tolerance=request.skew_tolerance, | ||
tolerance=request.tolerance, | ||
area_threshold=request.area_threshold, | ||
) | ||
return predictions | ||
except Exception as e: | ||
return {"error": str(e)} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
fastapi==0.103.2 | ||
uvicorn==0.22.0 | ||
fairpredictor | ||
tflite-runtime==2.14.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.