Skip to content

Commit

Permalink
Merge pull request nasa-gcn#2073 from jak574/across-api-async-changes
Browse files Browse the repository at this point in the history
ACROSS API: Move to using aioboto3 and make things as async as possible
  • Loading branch information
jak574 authored Mar 12, 2024
2 parents f2ee090 + f21c8fc commit f05d065
Show file tree
Hide file tree
Showing 18 changed files with 302 additions and 237 deletions.
18 changes: 12 additions & 6 deletions python/across_api/across/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from fastapi import Depends, Query, Security

from ..base.api import app
from ..auth.api import scope_authorize
from ..base.api import app
from .hello import Hello
from .resolve import Resolve
from .schema import HelloSchema, ResolveSchema
Expand Down Expand Up @@ -40,12 +40,14 @@ def your_name(

# API End points
@app.get("/")
def hello(name: YourNameDep) -> HelloSchema:
async def hello(name: YourNameDep) -> HelloSchema:
"""
This function returns a JSON response with a greeting message and an optional name parameter.
If the name parameter is provided, the greeting message will include the name.
"""
return Hello(name=name).schema
hello = Hello(name=name)
await hello.get()
return hello.schema


@app.get(
Expand All @@ -59,12 +61,16 @@ async def secure_hello(name: YourNameDep) -> HelloSchema:
This function returns a JSON response with a greeting message and an optional name parameter.
If the name parameter is provided, the greeting message will include the name.
"""
return Hello(name=name).schema
hello = Hello(name=name)
await hello.get()
return hello.schema


@app.get("/across/resolve")
def resolve(name: SourceNameDep) -> ResolveSchema:
async def resolve(name: SourceNameDep) -> ResolveSchema:
"""
Resolve the name of an astronomical object to its coordinates.
"""
return Resolve(name=name).schema
resolve = Resolve(name=name)
await resolve.get()
return resolve.schema
4 changes: 1 addition & 3 deletions python/across_api/across/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ def __init__(self, name: Optional[str] = None):
Your name
"""
self.name = name
if self.validate_get():
self.get()

def get(self) -> bool:
async def get(self) -> bool:
"""
GET method for ACROSS API Hello class.
Expand Down
13 changes: 6 additions & 7 deletions python/across_api/across/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
from typing import Optional, Tuple

import requests
import httpx
from astropy.coordinates.name_resolve import NameResolveError # type: ignore
from astropy.coordinates.sky_coordinate import SkyCoord # type: ignore

Expand All @@ -15,7 +15,7 @@
ANTARES_URL = "https://api.antares.noirlab.edu/v1/loci"


def antares_radec(ztf_id: str) -> Tuple[Optional[float], Optional[float]]:
async def antares_radec(ztf_id: str) -> Tuple[Optional[float], Optional[float]]:
"""
Query ANTARES API to find RA/Dec of a given ZTF source
Expand All @@ -38,7 +38,8 @@ def antares_radec(ztf_id: str) -> Tuple[Optional[float], Optional[float]]:
"sort": "-properties.newest_alert_observation_time",
"elasticsearch_query[locus_listing]": search_query,
}
r = requests.get(ANTARES_URL, params=params)
async with httpx.AsyncClient() as client:
r = await client.get(ANTARES_URL, params=params)
r.raise_for_status()
antares_data = r.json()
if antares_data["meta"]["count"] > 0:
Expand Down Expand Up @@ -91,10 +92,8 @@ def __init__(self, name: str):
self.dec = None
self.name = name
self.resolver = None
if self.validate_get():
self.get()

def get(self) -> bool:
async def get(self) -> bool:
"""
Retrieves the RA and Dec coordinates for a given name.
Expand All @@ -109,7 +108,7 @@ def get(self) -> bool:
"""Do a name search"""
# Check against the ANTARES broker
if "ZTF" in self.name:
ra, dec = antares_radec(self.name)
ra, dec = await antares_radec(self.name)
if ra is not None:
self.ra, self.dec = ra, dec
self.resolver = "ANTARES"
Expand Down
16 changes: 16 additions & 0 deletions python/across_api/base/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os

import aioboto3 # type: ignore[import]
from arc._lib import get_ports, use_aws # type: ignore[import]


def dynamodb_resource():
dynamodb_session = aioboto3.Session()
if use_aws():
return dynamodb_session.resource("dynamodb")
else:
port = get_ports()["tables"]
region_name = os.environ.get("AWS_REGION", "us-east-1")
return dynamodb_session.resource(
"dynamodb", endpoint_url=f"http://localhost:{port}", region_name=region_name
)
11 changes: 4 additions & 7 deletions python/across_api/base/ephem.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,13 @@ def ephindex(self, t: Time) -> int:
return index

def __init__(self, begin: Time, end: Time, stepsize: u.Quantity = 60 * u.s):
# Check if TLE is loaded
if self.tle is None:
raise HTTPException(
status_code=404, detail="No TLE available for this epoch"
)

# Parse inputs, round begin and end to stepsize
self.begin = round_time(begin, stepsize)
self.end = round_time(end, stepsize)
self.stepsize = stepsize

# Check the TLE is available
def compute_ephem(self) -> bool:
# Check if TLE is loaded
if self.tle is None:
raise HTTPException(
status_code=404, detail="No TLE available for this epoch"
Expand Down Expand Up @@ -163,3 +158,5 @@ def __init__(self, begin: Time, end: Time, stepsize: u.Quantity = 60 * u.s):
self.earthsize = self.earth_radius * np.ones(len(self))
else:
self.earthsize = np.arcsin(R_earth / dist)

return True
45 changes: 24 additions & 21 deletions python/across_api/base/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.


from arc.tables import name # type: ignore[import]
from datetime import datetime
from typing import Annotated, Any, List, Optional, Union

import astropy.units as u # type: ignore
from arc import tables # type: ignore[import]
from astropy.coordinates import Latitude, Longitude # type: ignore[import]
from astropy.time import Time # type: ignore
from pydantic import (
Expand All @@ -21,6 +20,8 @@
model_validator,
)

from .database import dynamodb_resource

# Define a Pydantic type for astropy Time objects, which will be serialized as
# a naive UTC datetime object, or a string in ISO format for JSON.
AstropyTime = Annotated[
Expand Down Expand Up @@ -228,7 +229,7 @@ def epoch(self) -> AstropyTime:
return Time(f"{year}-01-01", scale="utc") + (day_of_year - 1) * u.day

@classmethod
def find_tles_between_epochs(
async def find_tles_between_epochs(
cls, satname: str, start_epoch: Time, end_epoch: Time
) -> List[Any]:
"""
Expand All @@ -248,25 +249,27 @@ def find_tles_between_epochs(
-------
A list of TLEEntry objects between the specified epochs.
"""
table = tables.table(cls.__tablename__)

# Query the table for TLEs between the two epochs
response = table.query(
KeyConditionExpression="satname = :satname AND epoch BETWEEN :start_epoch AND :end_epoch",
ExpressionAttributeValues={
":satname": satname,
":start_epoch": str(start_epoch.utc.datetime),
":end_epoch": str(end_epoch.utc.datetime),
},
)

# Convert the response into a list of TLEEntry objects and return them
return [cls(**item) for item in response["Items"]]

def write(self) -> None:
async with dynamodb_resource() as dynamodb:
table = await dynamodb.Table(name(cls.__tablename__))

# Query the table for TLEs between the two epochs
response = await table.query(
KeyConditionExpression="satname = :satname AND epoch BETWEEN :start_epoch AND :end_epoch",
ExpressionAttributeValues={
":satname": satname,
":start_epoch": str(start_epoch.utc.datetime),
":end_epoch": str(end_epoch.utc.datetime),
},
)

# Convert the response into a list of TLEEntry objects and return them
return [cls(**item) for item in response["Items"]]

async def write(self) -> None:
"""Write the TLE entry to the database."""
table = tables.table(self.__tablename__)
table.put_item(Item=self.model_dump(mode="json"))
async with dynamodb_resource() as dynamodb:
table = await dynamodb.Table(name(self.__tablename__))
await table.put_item(Item=self.model_dump(mode="json"))


class TLESchema(BaseSchema):
Expand Down
59 changes: 30 additions & 29 deletions python/across_api/base/tle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import os
from typing import List, Optional

import requests
import httpx
from astropy.time import Time # type: ignore
from astropy.units import Quantity # type: ignore
from requests import HTTPError
from spacetrack import SpaceTrackClient # type: ignore
from httpx import HTTPError
from spacetrack import AsyncSpaceTrackClient # type: ignore

from .common import ACROSSAPIBase
from .schema import TLEEntry, TLEGetSchema, TLESchema
Expand Down Expand Up @@ -102,10 +102,8 @@ def __init__(self, epoch: Time, tle: Optional[TLEEntry] = None):
self.tles = [tle]
else:
self.tles = []
if self.validate_get():
self.get()

def read_tle_db(self) -> bool:
async def read_tle_db(self) -> bool:
"""
Read the best TLE for a given epoch from the local database of TLEs
Expand All @@ -115,15 +113,15 @@ def read_tle_db(self) -> bool:
"""
# Read TLEs from the database for a given `tle_name` and epoch within
# the allowed range
self.tles = TLEEntry.find_tles_between_epochs(
self.tles = await TLEEntry.find_tles_between_epochs(
self.tle_name,
self.epoch - self.tle_bad,
self.epoch + self.tle_bad,
)

return True

def read_tle_web(self) -> bool:
async def read_tle_web(self) -> bool:
"""
Read TLE from dedicated weblink.
Expand All @@ -143,7 +141,8 @@ def read_tle_web(self) -> bool:
return False

# Download TLE from internet
r = requests.get(self.tle_url)
async with httpx.AsyncClient() as client:
r = await client.get(self.tle_url)
try:
# Check for HTTP errors
r.raise_for_status()
Expand Down Expand Up @@ -172,7 +171,7 @@ def read_tle_web(self) -> bool:

return False

def read_tle_space_track(self) -> bool:
async def read_tle_space_track(self) -> bool:
"""
Read TLE from Space-Track.org.
Expand All @@ -195,19 +194,20 @@ def read_tle_space_track(self) -> bool:
epoch_stop = self.epoch + self.tle_bad

# Log into space-track.org
st = SpaceTrackClient(
async with AsyncSpaceTrackClient(
identity=os.environ.get("SPACE_TRACK_USER"),
password=os.environ.get("SPACE_TRACK_PASS"),
)

# Fetch the TLEs between the requested epochs
tletext = st.tle(
norad_cat_id=self.tle_norad_id,
orderby="epoch desc",
limit=22,
format="tle",
epoch=f">{epoch_start.datetime},<{epoch_stop.datetime}",
)
) as st:
await st.authenticate()

# Fetch the TLEs between the requested epochs
tletext = await st.tle(
norad_cat_id=self.tle_norad_id,
orderby="epoch desc",
limit=22,
format="tle",
epoch=f">{epoch_start.datetime},<{epoch_stop.datetime}",
)
# Check if we got a return
if tletext == "":
return False
Expand All @@ -234,7 +234,7 @@ def read_tle_space_track(self) -> bool:

return False

def read_tle_concat(self) -> bool:
async def read_tle_concat(self) -> bool:
"""
Read TLEs in the CONCAT MISSION_TLE_ARCHIVE.tle format. This format is
used by the CONCAT to store TLEs for various missions. The format
Expand All @@ -250,7 +250,8 @@ def read_tle_concat(self) -> bool:
return False

# Download TLEs from internet
r = requests.get(self.tle_concat)
async with httpx.AsyncClient() as client:
r = await client.get(self.tle_concat)
try:
# Check for HTTP errors
r.raise_for_status()
Expand Down Expand Up @@ -314,7 +315,7 @@ def tle_out_of_date(self) -> Optional[bool]:
return True
return False

def get(self) -> bool:
async def get(self) -> bool:
"""
Find in the best TLE for a given epoch. This method will first try to
read the TLE from the local database. If that fails, it will try to
Expand Down Expand Up @@ -350,17 +351,17 @@ def get(self) -> bool:
return True

# Fetch TLE from the TLE database
if self.read_tle_db() is True:
if await self.read_tle_db() is True:
if self.tle is not None:
return True

# Next try querying space-track.org for the TLE. This will only work if
# the environment variables SPACE_TRACK_USER and
# SPACE_TRACK_PASS are set, and valid.
if self.read_tle_space_track() is True:
if await self.read_tle_space_track() is True:
# Write the TLE to the database for next time
if self.tle is not None:
self.tle.write()
await self.tle.write()
return True

# Next try try reading the TLE given in the concatenated format at the
Expand All @@ -373,7 +374,7 @@ def get(self) -> bool:
if self.read_tle_concat() is True:
# Write the TLE to the database for next time
if self.tle is not None:
self.tle.write()
await self.tle.write()
return True

# Finally try reading from the web at the URL given by `tle_url`. Note
Expand All @@ -385,7 +386,7 @@ def get(self) -> bool:
if self.read_tle_web() is True:
# Write the TLE to the database for next time
if self.tle is not None:
self.tle.write()
await self.tle.write()
return True

# If we did not find any valid TLEs, then return False
Expand Down
Loading

0 comments on commit f05d065

Please sign in to comment.