From b9d4088d3c464ab8e88e1b475e15b29fc4b31e89 Mon Sep 17 00:00:00 2001 From: Mohamed Afifi Date: Sun, 26 Nov 2023 14:50:07 -0500 Subject: [PATCH] Fix crash when getOnGoingDownloads is called early Sometimes `DownloadManager.getOnGoingDownloads` may be called earlier than `DownloadManager.start`. This result in a crash. This change ensures that getOnGoingDownloads awaits the initialization. --- .../Sources/Features/AsyncInitializer.swift | 39 +++++++++++++++++++ .../DownloadBatchDataController.swift | 32 ++++++++------- .../Sources/Downloader/DownloadManager.swift | 26 +++++++------ .../Tests/DownloadManagerTests.swift | 18 +++++++++ .../BatchDownloaderFake.swift | 26 +++++++++++-- 5 files changed, 111 insertions(+), 30 deletions(-) create mode 100644 Core/Utilities/Sources/Features/AsyncInitializer.swift diff --git a/Core/Utilities/Sources/Features/AsyncInitializer.swift b/Core/Utilities/Sources/Features/AsyncInitializer.swift new file mode 100644 index 00000000..dcf864ef --- /dev/null +++ b/Core/Utilities/Sources/Features/AsyncInitializer.swift @@ -0,0 +1,39 @@ +// +// AsyncInitializer.swift +// +// +// Created by Mohamed Afifi on 2023-11-26. +// + +public struct AsyncInitializer { + // MARK: Lifecycle + + public init() { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream { continuation = $0 } + self.continuation = continuation + self.stream = stream + } + + // MARK: Public + + public private(set) var initialized = false + + public mutating func initialize() { + initialized = true + continuation.finish() + } + + public func awaitInitialization() async { + if initialized { + return + } + // Wait until the stream finishes + for await _ in stream {} + } + + // MARK: Private + + private let continuation: AsyncStream.Continuation + private let stream: AsyncStream +} diff --git a/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift b/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift index 97913e1d..d5419b32 100644 --- a/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift +++ b/Data/BatchDownloader/Sources/Downloader/DownloadBatchDataController.swift @@ -43,31 +43,23 @@ actor DownloadBatchDataController { // MARK: Internal - func bootstrapPersistence() async { - do { - try await attempt(times: 3) { - try await loadBatchesFromPersistence() - } - } catch { - crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.") - } - } - func start(with session: NetworkSession) async { + await bootstrapPersistence() + self.session = session let (_, _, downloadTasks) = await session.tasks() for batch in batches { await batch.associateTasks(downloadTasks) } - loadedInitialRunningTasks = true + initialRunningTasks.initialize() // start pending tasks if needed await startPendingTasksIfNeeded() } - func getOnGoingDownloads() -> [DownloadBatchResponse] { - precondition(loadedInitialRunningTasks) + func getOnGoingDownloads() async -> [DownloadBatchResponse] { + await initialRunningTasks.awaitInitialization() return Array(batches) } @@ -118,7 +110,7 @@ actor DownloadBatchDataController { private var batches: Set = [] - private var loadedInitialRunningTasks = false + private var initialRunningTasks = AsyncInitializer() private var runningTasks: Int { get async { @@ -139,6 +131,16 @@ actor DownloadBatchDataController { } } + private func bootstrapPersistence() async { + do { + try await attempt(times: 3) { + try await loadBatchesFromPersistence() + } + } catch { + crasher.recordError(error, reason: "Failed to retrieve initial download batches from persistence.") + } + } + private func loadBatchesFromPersistence() async throws { let batches = try await persistence.retrieveAll() logger.info("Loading \(batches.count) from persistence") @@ -172,7 +174,7 @@ actor DownloadBatchDataController { } private func startPendingTasksIfNeeded() async { - if !loadedInitialRunningTasks { + if !initialRunningTasks.initialized { return } diff --git a/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift b/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift index 37e364c8..2a0f28ea 100644 --- a/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift +++ b/Data/BatchDownloader/Sources/Downloader/DownloadManager.swift @@ -67,17 +67,7 @@ public final class DownloadManager: Sendable { public func start() async { logger.info("Starting download manager") - let operationQueue = OperationQueue() - operationQueue.name = "com.quran.downloads" - operationQueue.maxConcurrentOperationCount = 1 - - let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch") - operationQueue.underlyingQueue = dispatchQueue - - await dataController.bootstrapPersistence() - - let session = sessionFactory(handler, operationQueue) - self.session = session + let session = createSession() await dataController.start(with: session) logger.info("Download manager start completed") } @@ -101,4 +91,18 @@ public final class DownloadManager: Sendable { private var session: NetworkSession? private let handler: DownloadSessionDelegate private let dataController: DownloadBatchDataController + + private func createSession() -> NetworkSession { + let operationQueue = OperationQueue() + operationQueue.name = "com.quran.downloads" + operationQueue.maxConcurrentOperationCount = 1 + + let dispatchQueue = DispatchQueue(label: "com.quran.downloads.dispatch") + operationQueue.underlyingQueue = dispatchQueue + + let session = sessionFactory(handler, operationQueue) + self.session = session + + return session + } } diff --git a/Data/BatchDownloader/Tests/DownloadManagerTests.swift b/Data/BatchDownloader/Tests/DownloadManagerTests.swift index 9d04b16f..69bab2a7 100644 --- a/Data/BatchDownloader/Tests/DownloadManagerTests.swift +++ b/Data/BatchDownloader/Tests/DownloadManagerTests.swift @@ -67,6 +67,24 @@ final class DownloadManagerTests: XCTestCase { XCTAssertEqual(calls.calls, 1) } + func test_onGoingDownloads_whileStartNotFinished() async throws { + // Load a single batch + let batch = DownloadBatchRequest(requests: [request1]) + _ = try await downloader.download(batch) + + // Deallocate downloader & create new one + downloader = nil + downloader = await BatchDownloaderFake.makeDownloaderDontWaitForSession() + + // Test calling getOnGoingDownloads and start at the same time. + async let startTask: () = await downloader.start() + async let downloadsTask = await downloader.getOnGoingDownloads() + let (downloads, _) = await (downloadsTask, startTask) + + // Verify + XCTAssertEqual(downloads.count, 1) + } + func testLoadingOnGoingDownload() async throws { let emptyDownloads = await downloader.getOnGoingDownloads() XCTAssertEqual(emptyDownloads.count, 0) diff --git a/Data/BatchDownloaderFake/BatchDownloaderFake.swift b/Data/BatchDownloaderFake/BatchDownloaderFake.swift index c65e8ee5..95d8290f 100644 --- a/Data/BatchDownloaderFake/BatchDownloaderFake.swift +++ b/Data/BatchDownloaderFake/BatchDownloaderFake.swift @@ -20,10 +20,7 @@ public enum BatchDownloaderFake { public static let downloadsURL = RelativeFilePath(downloads, isDirectory: true) public static func makeDownloader(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> (DownloadManager, NetworkSessionFake) { - try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true) - let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false) - - let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url) + let persistence = makeDownloadsPersistence() actor SessionActor { var session: NetworkSessionFake! let channel = AsyncChannel() @@ -50,6 +47,20 @@ public enum BatchDownloaderFake { return (downloader, await sessionActor.session) } + public static func makeDownloaderDontWaitForSession(downloads: [SessionTask] = [], fileManager: FileSystem = DefaultFileSystem()) async -> DownloadManager { + let persistence = makeDownloadsPersistence() + let downloader = DownloadManager( + maxSimultaneousDownloads: maxSimultaneousDownloads, + sessionFactory: { delegate, queue in + let session = NetworkSessionFake(queue: queue, delegate: delegate, downloads: downloads) + return session + }, + persistence: persistence, + fileManager: fileManager + ) + return downloader + } + public static func tearDown() { try? FileManager.default.removeItem(at: Self.downloadsURL) } @@ -73,4 +84,11 @@ public enum BatchDownloaderFake { // MARK: Private private static let downloads = "downloads" + + private static func makeDownloadsPersistence() -> GRDBDownloadsPersistence { + try? DefaultFileSystem().createDirectory(at: Self.downloadsURL, withIntermediateDirectories: true) + let downloadsDBPath = Self.downloadsURL.appendingPathComponent("ongoing-downloads.db", isDirectory: false) + let persistence = GRDBDownloadsPersistence(fileURL: downloadsDBPath.url) + return persistence + } }