Skip to content

Commit

Permalink
Merge branch 'main' into mv/level
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Apr 2, 2024
2 parents 5dd9023 + 4ed6f6e commit 4aa37cc
Show file tree
Hide file tree
Showing 19 changed files with 324 additions and 213 deletions.
3 changes: 2 additions & 1 deletion fixbackend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from fixbackend.notification.email.scheduled_email import ScheduledEmailSender
from fixbackend.notification.notification_router import notification_router, unsubscribe_router
from fixbackend.notification.notification_service import NotificationService
from fixbackend.notification.user_notification_repo import UserNotificationSettingsRepositoryImpl
from fixbackend.permissions.role_repository import RoleRepositoryImpl
from fixbackend.permissions.router import roles_router
from fixbackend.sqlalechemy_extensions import EngineMetrics
Expand Down Expand Up @@ -275,7 +276,7 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
analytics_event_sender=analytics_event_sender,
),
)

deps.add(SN.user_notification_settings_repository, UserNotificationSettingsRepositoryImpl(session_maker))
if not cfg.static_assets:
await load_app_from_cdn()
async with deps:
Expand Down
15 changes: 13 additions & 2 deletions fixbackend/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,35 @@ def __init__(


class UserNotificationSettingsRead(BaseModel):
weekly_report: bool = Field(description="Whether to send a weekly report")
inactivity_reminder: bool = Field(description="Whether to send a reminder for open incidents")
weekly_report: bool = Field(description="Whether to receive a weekly report")
inactivity_reminder: bool = Field(description="Whether to receive a reminder for open incidents")
tutorial: bool = Field(description="Whether to receive tutorial emails")

@staticmethod
def from_model(model: UserNotificationSettings) -> "UserNotificationSettingsRead":
return UserNotificationSettingsRead(
weekly_report=model.weekly_report,
inactivity_reminder=model.inactivity_reminder,
tutorial=model.tutorial,
)

def to_model(self, user_id: UserId) -> UserNotificationSettings:
return UserNotificationSettings(
user_id=user_id,
weekly_report=self.weekly_report,
inactivity_reminder=self.inactivity_reminder,
tutorial=self.tutorial,
)


class UserNotificationSettingsWrite(BaseModel):
weekly_report: Optional[bool] = Field(default=None, description="Whether to receive a weekly report")
inactivity_reminder: Optional[bool] = Field(
default=None, description="Whether to receive a reminder for open incidents"
)
tutorial: Optional[bool] = Field(default=None, description="Whether to receive tutorial emails")


class OTPConfig(BaseModel):
secret: str = Field(description="TOTP secret")
recovery_codes: List[str] = Field(description="List of recovery codes")
15 changes: 6 additions & 9 deletions fixbackend/auth/users_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

from fastapi.routing import APIRouter
from fixbackend.auth.depedencies import AuthenticatedUser, fastapi_users
from fixbackend.auth.schemas import UserNotificationSettingsRead, UserRead, UserUpdate
from fixbackend.auth.schemas import UserNotificationSettingsRead, UserRead, UserUpdate, UserNotificationSettingsWrite


from fixbackend.notification.user_notification_repo import UserNotificationSettingsReporitoryDependency
from fixbackend.notification.user_notification_repo import UserNotificationSettingsRepositoryDependency


def users_router() -> APIRouter:
Expand All @@ -29,20 +28,18 @@ def users_router() -> APIRouter:
@router.get("/me/settings/notifications")
async def get_user_notification_settings(
user: AuthenticatedUser,
user_notification_repo: UserNotificationSettingsReporitoryDependency,
user_notification_repo: UserNotificationSettingsRepositoryDependency,
) -> UserNotificationSettingsRead:
settings = await user_notification_repo.get_notification_settings(user.id)
return UserNotificationSettingsRead.from_model(settings)

@router.put("/me/settings/notifications")
async def update_user_notification_settings(
user: AuthenticatedUser,
notification_settings: UserNotificationSettingsRead,
user_notification_repo: UserNotificationSettingsReporitoryDependency,
notification_settings: UserNotificationSettingsWrite,
user_notification_repo: UserNotificationSettingsRepositoryDependency,
) -> UserNotificationSettingsRead:
updated = await user_notification_repo.update_notification_settings(
user.id, notification_settings.to_model(user.id)
)
updated = await user_notification_repo.update_notification_settings(user.id, **notification_settings.dict())
return UserNotificationSettingsRead.from_model(updated)

return router
2 changes: 1 addition & 1 deletion fixbackend/billing_information/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def get_payment_methods(self, workspace: Workspace, user_id: UserId) -> Wo
current: PaymentMethod = PaymentMethods.NoPaymentMethod()

payment_methods: List[PaymentMethod] = []
if workspace.product_tier == ProductTier.Free:
if workspace.product_tier == ProductTier.Free or workspace.product_tier == ProductTier.Trial:
payment_methods.append(PaymentMethods.NoPaymentMethod())

async def get_current_subscription() -> Optional[AwsMarketplaceSubscription]:
Expand Down
229 changes: 117 additions & 112 deletions fixbackend/cloud_accounts/service_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,132 +324,137 @@ async def handle_stack_deleted(msg: Json) -> Optional[CloudAccount]:
return None

async def process_domain_event(self, message: Json, context: MessageContext) -> None:
log.info(f"Received domain event: {message}")

async def send_pub_sub_message(
e: Union[AwsAccountDegraded, AwsAccountDiscovered, AwsAccountDeleted, AwsAccountConfigured]
) -> None:
msg = e.to_json()
msg.pop("tenant_id", None)
await self.pubsub_publisher.publish(kind=e.kind, message=msg, channel=f"tenant-events::{e.tenant_id}")

match context.kind:
case TenantAccountsCollected.kind:
event = TenantAccountsCollected.from_json(message)

accounts = await self.cloud_account_repository.list(list(event.cloud_accounts.keys()))
collected_accounts = [account for account in accounts if account.id in event.cloud_accounts]
first_workspace_collect = all(account.last_scan_started_at is None for account in accounts)
first_account_collect = any(account.last_scan_started_at is None for account in collected_accounts)

set_workspace_id(event.tenant_id)
for account_id, account in event.cloud_accounts.items():
set_fix_cloud_account_id(account_id)
set_cloud_account_id(account.account_id)

def compute_failed_scan_count(acc: CloudAccount) -> int:
if account.scanned_resources < 50:
return acc.failed_scan_count + 1
else:
return 0

updated = await self.cloud_account_repository.update(
account_id,
lambda acc: evolve(
acc,
last_scan_duration_seconds=account.duration_seconds,
last_scan_resources_scanned=account.scanned_resources,
last_scan_started_at=account.started_at,
next_scan=event.next_run,
failed_scan_count=compute_failed_scan_count(acc),
),
async with asyncio.timeout(10):
match context.kind:
case TenantAccountsCollected.kind:
event = TenantAccountsCollected.from_json(message)

accounts = await self.cloud_account_repository.list(list(event.cloud_accounts.keys()))
collected_accounts = [account for account in accounts if account.id in event.cloud_accounts]
first_workspace_collect = all(account.last_scan_started_at is None for account in accounts)
first_account_collect = any(account.last_scan_started_at is None for account in collected_accounts)

set_workspace_id(event.tenant_id)
for account_id, account in event.cloud_accounts.items():
set_fix_cloud_account_id(account_id)
set_cloud_account_id(account.account_id)

def compute_failed_scan_count(acc: CloudAccount) -> int:
if account.scanned_resources < 50:
return acc.failed_scan_count + 1
else:
return 0

updated = await self.cloud_account_repository.update(
account_id,
lambda acc: evolve(
acc,
last_scan_duration_seconds=account.duration_seconds,
last_scan_resources_scanned=account.scanned_resources,
last_scan_started_at=account.started_at,
next_scan=event.next_run,
failed_scan_count=compute_failed_scan_count(acc),
),
)

if updated.failed_scan_count > 3:
await self.__degrade_account(updated.id, "Too many consecutive failed scans")

user_id = await self.analytics_event_sender.user_id_from_workspace(event.tenant_id)
if first_workspace_collect:
await self.analytics_event_sender.send(
AEFirstWorkspaceCollectFinished(user_id, event.tenant_id)
)
# inform workspace users about the first successful collect
await self.notification_service.send_message_to_workspace(
workspace_id=event.tenant_id, message=email.SecurityScanFinished()
)
if first_account_collect:
await self.analytics_event_sender.send(AEFirstAccountCollectFinished(user_id, event.tenant_id))

await self.analytics_event_sender.send(
AEWorkspaceCollectFinished(
user_id,
event.tenant_id,
len(collected_accounts),
sum(a.scanned_resources for a in event.cloud_accounts.values()),
)
)

if updated.failed_scan_count > 3:
await self.__degrade_account(updated.id, "Too many consecutive failed scans")
case AwsAccountDiscovered.kind:
discovered_event = AwsAccountDiscovered.from_json(message)
set_cloud_account_id(discovered_event.aws_account_id)
set_fix_cloud_account_id(discovered_event.cloud_account_id)
set_workspace_id(discovered_event.tenant_id)
await self.process_discovered_event(discovered_event)
await send_pub_sub_message(discovered_event)

case AwsAccountConfigured.kind:
configured_event = AwsAccountConfigured.from_json(message)
await send_pub_sub_message(configured_event)

user_id = await self.analytics_event_sender.user_id_from_workspace(event.tenant_id)
if first_workspace_collect:
await self.analytics_event_sender.send(AEFirstWorkspaceCollectFinished(user_id, event.tenant_id))
# inform workspace users about the first successful collect
case AwsAccountDeleted.kind:
deleted_event = AwsAccountDeleted.from_json(message)
await send_pub_sub_message(deleted_event)

case AwsAccountDegraded.kind:
degraded_event = AwsAccountDegraded.from_json(message)
await self.notification_service.send_message_to_workspace(
workspace_id=event.tenant_id, message=email.SecurityScanFinished()
)
if first_account_collect:
await self.analytics_event_sender.send(AEFirstAccountCollectFinished(user_id, event.tenant_id))

await self.analytics_event_sender.send(
AEWorkspaceCollectFinished(
user_id,
event.tenant_id,
len(collected_accounts),
sum(a.scanned_resources for a in event.cloud_accounts.values()),
workspace_id=degraded_event.tenant_id,
message=email.AccountDegraded(
cloud_account_id=degraded_event.aws_account_id,
tenant_id=degraded_event.tenant_id,
account_name=degraded_event.aws_account_name,
),
)
)

case AwsAccountDiscovered.kind:
discovered_event = AwsAccountDiscovered.from_json(message)
set_cloud_account_id(discovered_event.aws_account_id)
set_fix_cloud_account_id(discovered_event.cloud_account_id)
set_workspace_id(discovered_event.tenant_id)
await self.process_discovered_event(discovered_event)
await send_pub_sub_message(discovered_event)

case AwsAccountConfigured.kind:
configured_event = AwsAccountConfigured.from_json(message)
await send_pub_sub_message(configured_event)

case AwsAccountDeleted.kind:
deleted_event = AwsAccountDeleted.from_json(message)
await send_pub_sub_message(deleted_event)

case AwsAccountDegraded.kind:
degraded_event = AwsAccountDegraded.from_json(message)
await self.notification_service.send_message_to_workspace(
workspace_id=degraded_event.tenant_id,
message=email.AccountDegraded(
cloud_account_id=degraded_event.aws_account_id,
tenant_id=degraded_event.tenant_id,
account_name=degraded_event.aws_account_name,
),
)
await send_pub_sub_message(degraded_event)

case ProductTierChanged.kind:
ptc_evt = ProductTierChanged.from_json(message)
new_account_limit = ProductTierSettings[ptc_evt.product_tier].account_limit or math.inf
old_account_limit = ProductTierSettings[ptc_evt.previous_tier].account_limit or math.inf
if new_account_limit < old_account_limit:
# we should not have infinity here
new_account_limit = round(new_account_limit)
# tier changed, time to delete accounts
all_accounts = await self.list_accounts(ptc_evt.workspace_id)
# keep the last new_account_limit accounts
to_delete = all_accounts[:-new_account_limit]
# delete them all in parallel
await send_pub_sub_message(degraded_event)

case ProductTierChanged.kind:
ptc_evt = ProductTierChanged.from_json(message)
new_account_limit = ProductTierSettings[ptc_evt.product_tier].account_limit or math.inf
old_account_limit = ProductTierSettings[ptc_evt.previous_tier].account_limit or math.inf
if new_account_limit < old_account_limit:
# we should not have infinity here
new_account_limit = round(new_account_limit)
# tier changed, time to delete accounts
all_accounts = await self.list_accounts(ptc_evt.workspace_id)
# keep the last new_account_limit accounts
to_delete = all_accounts[:-new_account_limit]
# delete them all in parallel
async with asyncio.TaskGroup() as tg:
for cloud_account in to_delete:
tg.create_task(
self.delete_cloud_account(ptc_evt.user_id, cloud_account.id, ptc_evt.workspace_id)
)

case AwsMarketplaceSubscriptionCancelled.kind:
evt = AwsMarketplaceSubscriptionCancelled.from_json(message)
workspaces = await self.workspace_repository.list_workspaces_by_subscription_id(evt.subscription_id)
async with asyncio.TaskGroup() as tg:
for cloud_account in to_delete:
tg.create_task(
self.delete_cloud_account(ptc_evt.user_id, cloud_account.id, ptc_evt.workspace_id)
)

case AwsMarketplaceSubscriptionCancelled.kind:
evt = AwsMarketplaceSubscriptionCancelled.from_json(message)
workspaces = await self.workspace_repository.list_workspaces_by_subscription_id(evt.subscription_id)
async with asyncio.TaskGroup() as tg:
for ws in workspaces:
# first move the tier to free
await self.workspace_repository.update_payment_on_hold(ws.id, utc())
# second remove the subscription from the workspace
await self.workspace_repository.update_subscription(ws.id, None)
# third disable all accounts
account_limit = Free.account_limit or 1
all_accounts = await self.list_accounts(ws.id)
# keep the last account_limit accounts
to_disable = all_accounts[:-account_limit]
for cloud_account in to_disable:
tg.create_task(self.update_cloud_account_enabled(ws.id, cloud_account.id, False))
for ws in workspaces:
# first move the tier to free
await self.workspace_repository.update_payment_on_hold(ws.id, utc())
# second remove the subscription from the workspace
await self.workspace_repository.update_subscription(ws.id, None)
# third disable all accounts
account_limit = Free.account_limit or 1
all_accounts = await self.list_accounts(ws.id)
# keep the last account_limit accounts
to_disable = all_accounts[:-account_limit]
for cloud_account in to_disable:
tg.create_task(self.update_cloud_account_enabled(ws.id, cloud_account.id, False))

case _:
pass # ignore other domain events
case _:
pass # ignore other domain events

async def process_discovered_event(self, discovered: AwsAccountDiscovered) -> None:
account = await self.cloud_account_repository.get(discovered.cloud_account_id)
Expand Down
1 change: 1 addition & 0 deletions fixbackend/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class ServiceNames:
invitation_repository = "invitation_repository"
analytics_event_sender = "analytics_event_sender"
notification_service = "notification_service"
user_notification_settings_repository = "user_notification_settings_repository"
email_on_signup_consumer = "email_on_signup_consumer"
billing_entry_service = "billing_entry_services"
role_repository = "role_repository"
Expand Down
1 change: 1 addition & 0 deletions fixbackend/ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TaskId = NewType("TaskId", str)
BenchmarkName = NewType("BenchmarkName", str)
UserRoleId = NewType("UserRoleId", UUID)
Email = NewType("Email", str)


class NotificationProvider(StrEnum):
Expand Down
Loading

0 comments on commit 4aa37cc

Please sign in to comment.