-
Notifications
You must be signed in to change notification settings - Fork 0
add metadata and resumable download support with tests #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
02d2571
f590932
22b6892
bedfc7a
af26e60
b4e1c49
26707b8
5839d33
9d39cf1
97b6163
fe2f32b
30adb75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,9 @@ import Combine | |
|
|
||
| class Downloader: NSObject, ObservableObject { | ||
| private(set) var destination: URL | ||
| private(set) var metadataDestination: URL | ||
|
|
||
| private let chunkSize = 10 * 1024 * 1024 // 10MB | ||
|
|
||
| enum DownloadState { | ||
| case notStarted | ||
|
|
@@ -29,8 +32,21 @@ class Downloader: NSObject, ObservableObject { | |
|
|
||
| private var urlSession: URLSession? = nil | ||
|
|
||
| init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) { | ||
| init( | ||
| from url: URL, | ||
| to destination: URL, | ||
| metadataDirURL: URL, | ||
| using authToken: String? = nil, | ||
| inBackground: Bool = false, | ||
| resumeSize: Int = 0, | ||
| headers: [String: String]? = nil, | ||
| expectedSize: Int? = nil, | ||
| timeout: TimeInterval = 10, | ||
| numRetries: Int = 5 | ||
| ) { | ||
| self.destination = destination | ||
| let filename = (destination.lastPathComponent as NSString).deletingPathExtension | ||
| self.metadataDestination = metadataDirURL.appending(component: "\(filename).metadata") | ||
| super.init() | ||
| let sessionIdentifier = "swift-transformers.hub.downloader" | ||
|
|
||
|
|
@@ -43,10 +59,18 @@ class Downloader: NSObject, ObservableObject { | |
|
|
||
| self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil) | ||
|
|
||
| setupDownload(from: url, with: authToken) | ||
| setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries) | ||
| } | ||
|
|
||
| private func setupDownload(from url: URL, with authToken: String?) { | ||
| private func setupDownload( | ||
| from url: URL, | ||
| with authToken: String?, | ||
| resumeSize: Int, | ||
| headers: [String: String]?, | ||
| expectedSize: Int?, | ||
| timeout: TimeInterval, | ||
| numRetries: Int | ||
| ) { | ||
| downloadState.value = .downloading(0) | ||
| urlSession?.getAllTasks { tasks in | ||
| // If there's an existing pending background task with the same URL, let it proceed. | ||
|
|
@@ -71,14 +95,99 @@ class Downloader: NSObject, ObservableObject { | |
| } | ||
| } | ||
| var request = URLRequest(url: url) | ||
| var requestHeaders = headers ?? [:] | ||
| if let authToken = authToken { | ||
| request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization") | ||
| requestHeaders["Authorization"] = "Bearer \(authToken)" | ||
| } | ||
| if resumeSize > 0 { | ||
| requestHeaders["Range"] = "bytes=\(resumeSize)-" | ||
| } | ||
| request.timeoutInterval = timeout | ||
| request.allHTTPHeaderFields = requestHeaders | ||
|
|
||
| Task { | ||
| do { | ||
| try await self.downloadWithStreaming(request: request, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize) | ||
| } catch { | ||
| self.downloadState.value = .failed(error) | ||
| } | ||
| } | ||
|
|
||
| self.urlSession?.downloadTask(with: request).resume() | ||
| } | ||
| } | ||
|
|
||
/// Streams the response body for `request` into a temporary file and, once the
/// transfer finishes, moves that file to `destination` and marks the download
/// as completed.
///
/// Transient `URLError`s are retried inside a single loop that keeps appending
/// to the same temporary file. Each retry asks the server only for the bytes
/// that are still missing by advancing the `Range` header. Every other error
/// is propagated to the caller unchanged.
///
/// - Parameters:
///   - request: The request to perform. Any `Range` header it already carries
///     is used as-is for attempts made before new data has been written.
///   - resumeSize: Byte offset the caller is resuming from; used as the
///     starting point for the `Range` offset of later retries.
///   - numRetries: Number of consecutive failed attempts tolerated. The budget
///     is restored to this value every time a full chunk is written.
///   - expectedSize: When non-nil, the size the temporary file must have after
///     the transfer. NOTE(review): with `resumeSize > 0` the temporary file only
///     holds the bytes received by this call — confirm what callers pass here.
/// - Throws: `DownloadError.unexpectedError` for a missing session, a non-HTTP
///   or non-2xx response, a retry that is not answered with `206`, or a size
///   mismatch; the last `URLError` once the retry budget is exhausted; any
///   file-system error raised while writing or moving the file.
private func downloadWithStreaming(
    request: URLRequest,
    resumeSize: Int,
    numRetries: Int,
    expectedSize: Int?
) async throws {
    guard let session = self.urlSession else {
        throw DownloadError.unexpectedError
    }

    let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
    FileManager.default.createFile(atPath: tempURL.path, contents: nil)
    let tempFile = try FileHandle(forWritingTo: tempURL)

    var isClosed = false
    var didMove = false
    defer {
        // Close exactly once, and never leave an orphaned temp file behind on failure.
        if !isClosed { try? tempFile.close() }
        if !didMove { try? FileManager.default.removeItem(at: tempURL) }
    }

    var downloadedSize = resumeSize
    var retriesLeft = numRetries
    var buffer = Data(capacity: chunkSize)

    while true {
        // True once this call has persisted bytes of its own, i.e. on a retry after progress.
        let hasPartialData = downloadedSize > resumeSize
        var attemptRequest = request
        if hasPartialData {
            attemptRequest.setValue("bytes=\(downloadedSize)-", forHTTPHeaderField: "Range")
        }

        var streamFailure: URLError?
        do {
            let (asyncBytes, response) = try await session.bytes(for: attemptRequest)
            guard let httpResponse = response as? HTTPURLResponse,
                  (200..<300).contains(httpResponse.statusCode)
            else {
                throw DownloadError.unexpectedError
            }
            // A server that ignores the Range header would resend the file from the
            // start; appending that to the partial data would corrupt the result.
            if hasPartialData && httpResponse.statusCode != 206 {
                throw DownloadError.unexpectedError
            }

            for try await byte in asyncBytes {
                buffer.append(byte)
                if buffer.count >= chunkSize {
                    try tempFile.write(contentsOf: buffer)
                    downloadedSize += buffer.count
                    buffer.removeAll(keepingCapacity: true)
                    // Progress was made, so restore the caller's retry budget.
                    retriesLeft = numRetries
                }
            }
        } catch let error as URLError {
            streamFailure = error
        }

        // Persist the trailing partial chunk whether the stream ended normally or
        // failed, so a retry resumes exactly after the last byte received.
        if !buffer.isEmpty {
            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
        }

        guard let failure = streamFailure else {
            break
        }
        guard retriesLeft > 0 else {
            throw failure
        }
        retriesLeft -= 1
        try await Task.sleep(nanoseconds: 1_000_000_000)
    }

    let actualSize = try tempFile.seekToEnd()
    if let expectedSize = expectedSize, expectedSize != actualSize {
        throw DownloadError.unexpectedError
    }

    try tempFile.close()
    isClosed = true
    try FileManager.default.moveDownloadedFile(from: tempURL, to: destination)
    didMove = true

    downloadState.value = .completed(destination)
}
|
|
||
| @discardableResult | ||
| func waitUntilDone() throws -> URL { | ||
| // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we allowing this to potentially be `nil`? Is there a case where you can download from HF without providing an auth token?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was made optional by Hugging Face. You don't need an auth token to download from public repos.