Skip to content
Closed
121 changes: 115 additions & 6 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import Combine

class Downloader: NSObject, ObservableObject {
private(set) var destination: URL
private(set) var metadataDestination: URL

private let chunkSize = 10 * 1024 * 1024 // 10MB

enum DownloadState {
case notStarted
Expand All @@ -29,8 +32,21 @@ class Downloader: NSObject, ObservableObject {

private var urlSession: URLSession? = nil

init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
init(
from url: URL,
to destination: URL,
metadataDirURL: URL,
using authToken: String? = nil,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we allowing this to potentially be nil? Is there a case where you can download from HF without providing an auth token?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was made optional by huggingface. you don't need and auth token to download from public repos

inBackground: Bool = false,
resumeSize: Int = 0,
headers: [String: String]? = nil,
expectedSize: Int? = nil,
timeout: TimeInterval = 10,
numRetries: Int = 5
) {
self.destination = destination
let filename = (destination.lastPathComponent as NSString).deletingPathExtension
self.metadataDestination = metadataDirURL.appending(component: "\(filename).metadata")
super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"

Expand All @@ -43,10 +59,18 @@ class Downloader: NSObject, ObservableObject {

self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

setupDownload(from: url, with: authToken)
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
}

private func setupDownload(from url: URL, with authToken: String?) {
private func setupDownload(
from url: URL,
with authToken: String?,
resumeSize: Int,
headers: [String: String]?,
expectedSize: Int?,
timeout: TimeInterval,
numRetries: Int
) {
downloadState.value = .downloading(0)
urlSession?.getAllTasks { tasks in
// If there's an existing pending background task with the same URL, let it proceed.
Expand All @@ -71,14 +95,99 @@ class Downloader: NSObject, ObservableObject {
}
}
var request = URLRequest(url: url)
var requestHeaders = headers ?? [:]
if let authToken = authToken {
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"
}
request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders

Task {
do {
try await self.downloadWithStreaming(request: request, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
} catch {
self.downloadState.value = .failed(error)
}
}

self.urlSession?.downloadTask(with: request).resume()
}
}

private func downloadWithStreaming(
request: URLRequest,
resumeSize: Int,
numRetries: Int,
expectedSize: Int?
) async throws {
guard let session = self.urlSession else {
throw DownloadError.unexpectedError
}
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
let tempFile = try FileHandle(forWritingTo: tempURL)

defer { tempFile.closeFile() }

let (asyncBytes, response) = try await session.bytes(for: request)
guard let response = response as? HTTPURLResponse else {
throw DownloadError.unexpectedError
}

guard (200..<300).contains(response.statusCode) else {
throw DownloadError.unexpectedError
}

var downloadedSize = resumeSize

var buffer = Data(capacity: chunkSize)
var newNumRetries = numRetries

do {
for try await byte in asyncBytes {
buffer.append(byte)
if buffer.count == chunkSize {
if !buffer.isEmpty { // Filter out keep-alive chunks
try tempFile.write(contentsOf: buffer)
buffer.removeAll(keepingCapacity: true)
downloadedSize += chunkSize
newNumRetries = 5
}
}
}

if !buffer.isEmpty {
try tempFile.write(contentsOf: buffer)
downloadedSize += buffer.count
buffer.removeAll(keepingCapacity: true)
newNumRetries = 5
}
} catch let error as URLError {
if newNumRetries <= 0 {
throw error
}
try await Task.sleep(nanoseconds: 1_000_000_000)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just specify one second? It's less confusing and doesn't require someone to do some quick math to figure out the conversion.


try await downloadWithStreaming(
request: request,
resumeSize: downloadedSize,
numRetries: newNumRetries - 1,
expectedSize: expectedSize
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means we would be retrying the request on every error, but not every error is retryable. For instance, if the HF API returns 400, 401 or 403, we shouldn't retry that request because it's never going to succeed. We should only retry the request if the response from HF is in the [500, 599] range, which is a server-side error, which can be transient (hence, retryable).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, generally recommend against implementing this with recursion since it's slightly less readable compared to a simple iterative solution (where we try/catch the error and retry the request until you reach the preset number of retries).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking for error ranges makes sense. I primarily did it this way to follow the design decisions made in the python library, which uses recursion for retries.

}

let actualSize = try tempFile.seekToEnd()
if let expectedSize = expectedSize, expectedSize != actualSize {
throw DownloadError.unexpectedError
}

tempFile.closeFile()
try FileManager.default.moveDownloadedFile(from: tempURL, to: destination)

downloadState.value = .completed(destination)
}

@discardableResult
func waitUntilDone() throws -> URL {
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
Expand Down
Loading