diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 8e0513084..df1f5e440 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -6,6 +6,8 @@ ### New Features and Improvements +- Upload of big files (> 5 GB) to UC Volumes using multipart chunking ([#1621](https://github.com/databricks/databricks-sdk-go/pull/1621)). + ### Bug Fixes ### Documentation @@ -29,4 +31,4 @@ * Add `GpuXlarge` enum value for [serving.ServedModelInputWorkloadType](https://pkg.go.dev/github.com/databricks/databricks-sdk-go/service/serving#ServedModelInputWorkloadType). * Add `GpuXlarge` enum value for [serving.ServingModelWorkloadType](https://pkg.go.dev/github.com/databricks/databricks-sdk-go/service/serving#ServingModelWorkloadType). * [Breaking] Change `ListFeatures` method for [w.FeatureEngineering](https://pkg.go.dev/github.com/databricks/databricks-sdk-go/service/ml#FeatureEngineeringAPI) workspace-level service with new required argument order. -* [Breaking] Remove `UnspecifiedResourceName` field for [postgres.RequestedResource](https://pkg.go.dev/github.com/databricks/databricks-sdk-go/service/postgres#RequestedResource). \ No newline at end of file +* [Breaking] Remove `UnspecifiedResourceName` field for [postgres.RequestedResource](https://pkg.go.dev/github.com/databricks/databricks-sdk-go/service/postgres#RequestedResource). diff --git a/experimental/mocks/service/files/mock_files_upload.go b/experimental/mocks/service/files/mock_files_upload.go new file mode 100644 index 000000000..2ed7b459a --- /dev/null +++ b/experimental/mocks/service/files/mock_files_upload.go @@ -0,0 +1,55 @@ +// Hand-written mock stubs for filesAPIUploadUtilities methods. +// These will be replaced by mockery-generated code on the next `make codegen`. + +package files + +import ( + "context" + "io" + + files "github.com/databricks/databricks-sdk-go/service/files" +) + +// UploadWithChunking provides a mock function for the FilesInterface. 
+func (_m *MockFilesInterface) UploadWithChunking(ctx context.Context, filePath string, content io.ReadSeeker, contentLength int64, opts ...files.UploadOption) error { + _va := make([]any, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []any + _ca = append(_ca, ctx, filePath, content, contentLength) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for UploadWithChunking") + } + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, io.ReadSeeker, int64, ...files.UploadOption) error); ok { + r0 = rf(ctx, filePath, content, contentLength, opts...) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// UploadFromFile provides a mock function for the FilesInterface. +func (_m *MockFilesInterface) UploadFromFile(ctx context.Context, filePath string, sourcePath string, opts ...files.UploadOption) error { + _va := make([]any, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []any + _ca = append(_ca, ctx, filePath, sourcePath) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + if len(ret) == 0 { + panic("no return value specified for UploadFromFile") + } + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, ...files.UploadOption) error); ok { + r0 = rf(ctx, filePath, sourcePath, opts...) + } else { + r0 = ret.Error(0) + } + return r0 +} diff --git a/service/files/api.go b/service/files/api.go index 432828c39..4b4c6453d 100755 --- a/service/files/api.go +++ b/service/files/api.go @@ -217,6 +217,7 @@ func (a *DbfsAPI) MkdirsByPath(ctx context.Context, path string) error { } type FilesInterface interface { + filesAPIUploadUtilities // Creates an empty directory. If necessary, also creates any parent directories // of the new, empty directory (like the shell command `mkdir -p`). 
If called on diff --git a/service/files/ext_upload.go b/service/files/ext_upload.go new file mode 100644 index 000000000..0f2dbfba0 --- /dev/null +++ b/service/files/ext_upload.go @@ -0,0 +1,628 @@ +package files + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net/http" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/databricks/databricks-sdk-go/logger" + "github.com/databricks/databricks-sdk-go/useragent" +) + +const ( + // minMultipartUploadSize is the minimum file size (in bytes) to trigger multipart upload. + minMultipartUploadSize = 50 * 1024 * 1024 // 50 MiB + + // defaultPartSize is the default part size for multipart uploads. + defaultPartSize = 10 * 1024 * 1024 // 10 MiB + + // maxPartSize is the maximum part size for multipart uploads (Azure limit). + maxPartSize = 4 * 1024 * 1024 * 1024 // 4 GiB + + // defaultParallelism is the default number of concurrent upload workers. + defaultParallelism = 10 + + // maxPartsTarget is the target maximum number of parts for an upload. + maxPartsTarget = 100 + + // urlExpirationDuration is how long presigned URLs are valid. + urlExpirationDuration = 1 * time.Hour + + // maxURLExpirationRetries is the maximum number of retries for URL expiration. + maxURLExpirationRetries = 3 +) + +// partSizeOptions lists the candidate part sizes in ascending order. +var partSizeOptions = []int64{ + 10 * 1024 * 1024, // 10 MiB + 20 * 1024 * 1024, // 20 MiB + 50 * 1024 * 1024, // 50 MiB + 100 * 1024 * 1024, // 100 MiB + 200 * 1024 * 1024, // 200 MiB + 500 * 1024 * 1024, // 500 MiB + 1 * 1024 * 1024 * 1024, // 1 GiB + 2 * 1024 * 1024 * 1024, // 2 GiB + 4 * 1024 * 1024 * 1024, // 4 GiB +} + +// filesAPIUploadUtilities defines the hand-written upload extension methods +// for FilesAPI. This interface is embedded in FilesInterface (in api.go) so +// that the methods are accessible through WorkspaceClient.Files. 
type filesAPIUploadUtilities interface {
	// UploadWithChunking sends content to a Unity Catalog volume path.
	// Payloads below 50 MiB go through the regular single-shot Upload; larger
	// payloads use the multipart protocol with parts uploaded concurrently.
	// When the first part of a multipart upload is rejected, the whole payload
	// is retried as a single-shot upload.
	//
	// filePath is the absolute remote path (e.g. /Volumes/catalog/schema/volume/file.bin).
	// content must be seekable because it is rewound for the single-shot fallback.
	// contentLength is the total payload size in bytes.
	UploadWithChunking(ctx context.Context, filePath string, content io.ReadSeeker, contentLength int64, opts ...UploadOption) error

	// UploadFromFile opens the local file at sourcePath, determines its size,
	// and uploads it to filePath with the same single-shot / multipart
	// selection as UploadWithChunking.
	UploadFromFile(ctx context.Context, filePath string, sourcePath string, opts ...UploadOption) error
}

// UploadConfig collects the tunables of an upload. It is normally populated
// through UploadOption values rather than constructed directly.
type UploadConfig struct {
	// PartSize is the size in bytes of each multipart part. A value that is
	// not positive lets the SDK pick a size from the content length.
	PartSize int64
	// Parallelism is the number of parts uploaded concurrently. Values below
	// one are replaced by the default.
	Parallelism int
	// Overwrite requests replacement of an existing remote file.
	Overwrite bool
}

// UploadOption mutates an UploadConfig; options are applied in the order given.
type UploadOption func(*UploadConfig)

// WithPartSize returns an option that fixes the multipart part size, in bytes.
func WithPartSize(partSize int64) UploadOption {
	return func(cfg *UploadConfig) { cfg.PartSize = partSize }
}

// WithParallelism returns an option that sets how many parts are uploaded at once.
func WithParallelism(parallelism int) UploadOption {
	return func(cfg *UploadConfig) { cfg.Parallelism = parallelism }
}

// WithOverwrite returns an option that controls whether an existing remote
// file may be replaced.
func WithOverwrite(overwrite bool) UploadOption {
	return func(cfg *UploadConfig) { cfg.Overwrite = overwrite }
}

// initiateUploadResponse is the response from initiating a multipart upload.
//
// It is decoded from the reply to the "initiate-upload" action. The multipart
// code path requires MultipartUpload to be set and reports an error otherwise;
// ResumableUpload is decoded but not consumed by the multipart code path.
type initiateUploadResponse struct {
	MultipartUpload *multipartUploadSession `json:"multipart_upload,omitempty"`
	ResumableUpload *resumableUploadSession `json:"resumable_upload,omitempty"`
}

// multipartUploadSession holds the state for a multipart upload session.
type multipartUploadSession struct {
	// SessionToken identifies the session; it is echoed back when requesting
	// part URLs, completing, or aborting the upload.
	SessionToken string `json:"session_token"`
}

// resumableUploadSession holds the state for a resumable upload session.
type resumableUploadSession struct {
	SessionToken string `json:"session_token"`
}

// presignedURL represents a presigned URL for uploading a part.
type presignedURL struct {
	// URL is the target of the PUT request that carries the part's bytes.
	URL string `json:"url"`
	// PartNumber is the part this URL was issued for.
	PartNumber int `json:"part_number"`
	// Headers are copied onto the PUT request verbatim.
	Headers []presignedHeader `json:"headers"`
}

// presignedHeader is a header to include when uploading to a presigned URL.
type presignedHeader struct {
	Name  string `json:"name"`
	Value string `json:"value"`
}

// createUploadPartURLsRequest is the request to create presigned URLs for upload parts.
type createUploadPartURLsRequest struct {
	// Path is the remote file path being uploaded.
	Path string `json:"path"`
	// SessionToken is the token returned when the upload was initiated.
	SessionToken string `json:"session_token"`
	// StartPartNumber is the first part number to issue a URL for.
	StartPartNumber int `json:"start_part_number"`
	// Count is how many consecutive part URLs are requested.
	Count int `json:"count"`
	// ExpireTime is the requested URL expiry as an RFC 3339 timestamp.
	ExpireTime string `json:"expire_time"`
}

// createUploadPartURLsResponse is the response containing presigned URLs for upload parts.
type createUploadPartURLsResponse struct {
	UploadPartURLs []presignedURL `json:"upload_part_urls"`
}

// completeUploadPart represents a completed upload part.
type completeUploadPart struct {
	PartNumber int `json:"part_number"`
	// ETag is the value of the ETag header returned by the part's PUT response.
	ETag string `json:"etag"`
}

// completeUploadRequest is the request to complete a multipart upload.
type completeUploadRequest struct {
	// Parts lists every uploaded part; the sender orders it by ascending part number.
	Parts []completeUploadPart `json:"parts"`
}

// optimizePartSize selects the best part size and batch size for a multipart upload.
//
// If explicitPartSize > 0, it is used directly and the batch size is computed
// as max(1, ceil(sqrt(numParts))).
+// +// If contentLength <= 0 (unknown), defaultPartSize and batch size 1 are returned. +// +// Otherwise, the smallest part size from partSizeOptions where the number of +// parts is <= maxPartsTarget is selected. If no option satisfies the constraint, +// maxPartSize is used as a fallback. +func optimizePartSize(contentLength int64, explicitPartSize int64) (int64, int) { + if explicitPartSize > 0 { + numParts := int(math.Ceil(float64(contentLength) / float64(explicitPartSize))) + if numParts < 1 { + numParts = 1 + } + batchSize := int(math.Ceil(math.Sqrt(float64(numParts)))) + if batchSize < 1 { + batchSize = 1 + } + return explicitPartSize, batchSize + } + + if contentLength <= 0 { + return defaultPartSize, 1 + } + + for _, partSize := range partSizeOptions { + numParts := int(math.Ceil(float64(contentLength) / float64(partSize))) + if numParts <= maxPartsTarget { + batchSize := int(math.Ceil(math.Sqrt(float64(numParts)))) + if batchSize < 1 { + batchSize = 1 + } + return partSize, batchSize + } + } + + // Fallback to maxPartSize + numParts := int(math.Ceil(float64(contentLength) / float64(maxPartSize))) + if numParts < 1 { + numParts = 1 + } + batchSize := int(math.Ceil(math.Sqrt(float64(numParts)))) + if batchSize < 1 { + batchSize = 1 + } + return maxPartSize, batchSize +} + +// uploadURLExpireTime returns the expiration time for a presigned URL as an RFC3339 string. +func uploadURLExpireTime() string { + return time.Now().Add(urlExpirationDuration).UTC().Format(time.RFC3339) +} + +// initiateMultipartUpload starts a multipart upload session for the given file path. 
+func (a *FilesAPI) initiateMultipartUpload(ctx context.Context, filePath string, overwrite bool) (*initiateUploadResponse, error) { + var resp initiateUploadResponse + apiPath := fmt.Sprintf("/api/2.0/fs/files%s", httpclient.EncodeMultiSegmentPathParameter(filePath)) + headers := make(map[string]string) + headers["Accept"] = "application/json" + queryParams := make(map[string]any) + queryParams["action"] = "initiate-upload" + if overwrite { + queryParams["overwrite"] = "true" + } + err := a.filesImpl.client.Do(ctx, http.MethodPost, apiPath, headers, queryParams, nil, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +// getUploadPartURLs fetches presigned URLs for uploading parts. +func (a *FilesAPI) getUploadPartURLs(ctx context.Context, filePath, sessionToken string, startPartNumber, count int) ([]presignedURL, error) { + var resp createUploadPartURLsResponse + headers := make(map[string]string) + headers["Content-Type"] = "application/json" + headers["Accept"] = "application/json" + queryParams := make(map[string]any) + req := createUploadPartURLsRequest{ + Path: filePath, + SessionToken: sessionToken, + StartPartNumber: startPartNumber, + Count: count, + ExpireTime: uploadURLExpireTime(), + } + err := a.filesImpl.client.Do(ctx, http.MethodPost, "/api/2.0/fs/create-upload-part-urls", headers, queryParams, req, &resp) + if err != nil { + return nil, err + } + return resp.UploadPartURLs, nil +} + +// isURLExpiredResponse returns true if the HTTP response indicates the presigned URL has expired. +func isURLExpiredResponse(statusCode int, body []byte) bool { + if statusCode != http.StatusForbidden { + return false + } + s := string(body) + return strings.Contains(s, "AccessDenied") || strings.Contains(s, "AuthenticationFailed") +} + +// uploadOnePart uploads a single part to a presigned URL and returns the ETag. 
+func (a *FilesAPI) uploadOnePart(ctx context.Context, presigned presignedURL, partData io.ReadSeeker, contentLength int64) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, presigned.URL, partData) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.ContentLength = contentLength + req.Header.Set("Content-Type", "application/octet-stream") + for _, h := range presigned.Headers { + req.Header.Set(h.Name, h.Value) + } + + // Use a client without a total timeout — context cancellation handles + // overall deadline, and large parts may legitimately take a long time. + httpClient := &http.Client{} + resp, err := httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("upload part %d failed: %w", presigned.PartNumber, err) + } + + if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated { + etag := resp.Header.Get("ETag") + resp.Body.Close() + return etag, nil + } + + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return "", &partUploadError{ + partNumber: presigned.PartNumber, + statusCode: resp.StatusCode, + body: body, + } +} + +// partUploadError represents a failed part upload with status and body. +type partUploadError struct { + partNumber int + statusCode int + body []byte +} + +func (e *partUploadError) Error() string { + return fmt.Sprintf("upload part %d failed with status %d: %s", e.partNumber, e.statusCode, string(e.body)) +} + +func (e *partUploadError) isExpiredURL() bool { + return isURLExpiredResponse(e.statusCode, e.body) +} + +// uploadOnePartWithRetry uploads a part, re-fetching the presigned URL if it expires. +func (a *FilesAPI) uploadOnePartWithRetry(ctx context.Context, filePath, sessionToken string, partNumber int, partData io.ReadSeeker, contentLength int64) (string, error) { + for attempt := 0; attempt <= maxURLExpirationRetries; attempt++ { + // Rewind reader for retries. 
+ if attempt > 0 { + if _, err := partData.Seek(0, io.SeekStart); err != nil { + return "", fmt.Errorf("failed to rewind part data: %w", err) + } + } + + // Fetch a (fresh) presigned URL. + urls, err := a.getUploadPartURLs(ctx, filePath, sessionToken, partNumber, 1) + if err != nil { + return "", fmt.Errorf("failed to get upload URL for part %d: %w", partNumber, err) + } + if len(urls) == 0 { + return "", fmt.Errorf("no upload URL returned for part %d", partNumber) + } + + etag, err := a.uploadOnePart(ctx, urls[0], partData, contentLength) + if err != nil { + var pErr *partUploadError + if errors.As(err, &pErr) && pErr.isExpiredURL() && attempt < maxURLExpirationRetries { + logger.Debugf(ctx, "presigned URL expired for part %d (attempt %d/%d), fetching new URL", + partNumber, attempt+1, maxURLExpirationRetries) + continue + } + return "", err + } + return etag, nil + } + + return "", fmt.Errorf("upload part %d: presigned URL expired after %d retries", partNumber, maxURLExpirationRetries) +} + +// completeMultipartUpload finalizes a multipart upload with the given ETags. 
+func (a *FilesAPI) completeMultipartUpload(ctx context.Context, filePath, sessionToken string, etags map[int]string) error { + apiPath := fmt.Sprintf("/api/2.0/fs/files%s", httpclient.EncodeMultiSegmentPathParameter(filePath)) + headers := make(map[string]string) + headers["Content-Type"] = "application/json" + headers["Accept"] = "application/json" + queryParams := make(map[string]any) + queryParams["action"] = "complete-upload" + queryParams["upload_type"] = "multipart" + queryParams["session_token"] = sessionToken + + parts := make([]completeUploadPart, 0, len(etags)) + for partNum, etag := range etags { + parts = append(parts, completeUploadPart{ + PartNumber: partNum, + ETag: etag, + }) + } + sort.Slice(parts, func(i, j int) bool { + return parts[i].PartNumber < parts[j].PartNumber + }) + + req := completeUploadRequest{Parts: parts} + return a.filesImpl.client.Do(ctx, http.MethodPost, apiPath, headers, queryParams, req, nil) +} + +// errFallbackToSingleShot signals that multipart upload failed on the first +// chunk and the caller should retry with single-shot upload. +type errFallbackToSingleShot struct { + reason error +} + +func (e *errFallbackToSingleShot) Error() string { + return fmt.Sprintf("falling back to single-shot upload: %v", e.reason) +} + +// uploadMultipart orchestrates a full multipart upload. If the first part +// fails (e.g., 403 from Azure firewall), it returns errFallbackToSingleShot +// so the caller can retry with the standard Upload method. 
+func (a *FilesAPI) uploadMultipart(ctx context.Context, filePath string, content io.ReadSeeker, contentLength int64, cfg *UploadConfig) error { + if contentLength <= 0 { + return fmt.Errorf("contentLength must be positive, got %d", contentLength) + } + + // Phase 1: Initiate + initResp, err := a.initiateMultipartUpload(ctx, filePath, cfg.Overwrite) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + if initResp.MultipartUpload == nil { + return fmt.Errorf("multipart upload not supported for this path (GCP is not supported)") + } + sessionToken := initResp.MultipartUpload.SessionToken + logger.Debugf(ctx, "initiated multipart upload with session token") + + partSize := cfg.PartSize + parallelism := cfg.Parallelism + if parallelism < 1 { + parallelism = defaultParallelism + } + + // Phase 2a: Upload first part synchronously to detect early failures. + // If the first part fails (e.g. Azure firewall 403), signal fallback to + // single-shot upload so the caller can retry without multipart. + etags := make(map[int]string) + var totalBytes int64 + + firstBuf := make([]byte, partSize) + firstN, firstReadErr := io.ReadFull(content, firstBuf) + if firstN == 0 && firstReadErr != nil { + // Empty content — nothing to upload. Complete with zero parts. + return a.completeMultipartUpload(ctx, filePath, sessionToken, etags) + } + firstBuf = firstBuf[:firstN] + totalBytes += int64(firstN) + + firstEtag, err := a.uploadOnePartWithRetry(ctx, filePath, sessionToken, 1, bytes.NewReader(firstBuf), int64(firstN)) + if err != nil { + a.abortMultipartUpload(ctx, filePath, sessionToken) + return &errFallbackToSingleShot{reason: err} + } + etags[1] = firstEtag + logger.Debugf(ctx, "uploaded part 1 (first chunk validated)") + + // Phase 2b: Upload remaining parts in parallel. 
+ if firstReadErr == nil { + sem := make(chan struct{}, parallelism) + var mu sync.Mutex + var uploadErr error + + partNumber := 2 + for { + if err := ctx.Err(); err != nil { + break + } + mu.Lock() + if uploadErr != nil { + mu.Unlock() + break + } + mu.Unlock() + + buf := make([]byte, partSize) + n, readErr := io.ReadFull(content, buf) + if n == 0 && readErr != nil { + break + } + buf = buf[:n] + partDataLen := int64(n) + totalBytes += partDataLen + + currentPartNumber := partNumber + partNumber++ + + sem <- struct{}{} + go func(pn int, data []byte) { + defer func() { <-sem }() + + mu.Lock() + if uploadErr != nil { + mu.Unlock() + return + } + mu.Unlock() + if ctx.Err() != nil { + return + } + + etag, err := a.uploadOnePartWithRetry(ctx, filePath, sessionToken, pn, bytes.NewReader(data), partDataLen) + if err != nil { + mu.Lock() + if uploadErr == nil { + uploadErr = err + } + mu.Unlock() + return + } + + mu.Lock() + etags[pn] = etag + mu.Unlock() + logger.Debugf(ctx, "uploaded part %d", pn) + }(currentPartNumber, buf) + + if readErr != nil { + break + } + } + + // Wait for all goroutines to finish. + for i := 0; i < parallelism; i++ { + sem <- struct{}{} + } + + if uploadErr != nil { + a.abortMultipartUpload(ctx, filePath, sessionToken) + return uploadErr + } + } + + // Verify total bytes read matches the declared content length. + if totalBytes != contentLength { + a.abortMultipartUpload(ctx, filePath, sessionToken) + return fmt.Errorf("content length mismatch: declared %d bytes but read %d bytes", contentLength, totalBytes) + } + + // Phase 3: Complete + logger.Debugf(ctx, "completing multipart upload with %d parts", len(etags)) + return a.completeMultipartUpload(ctx, filePath, sessionToken, etags) +} + +// abortMultipartUpload attempts to abort an in-progress multipart upload. +// This is a best-effort cleanup; errors are logged but not returned. 
+func (a *FilesAPI) abortMultipartUpload(ctx context.Context, filePath, sessionToken string) { + apiPath := "/api/2.0/fs/create-abort-upload-url" + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "application/json", + } + body := map[string]string{ + "path": filePath, + "session_token": sessionToken, + "expire_time": uploadURLExpireTime(), + } + + var resp struct { + AbortUploadURL struct { + URL string `json:"url"` + Headers []presignedHeader `json:"headers,omitempty"` + } `json:"abort_upload_url"` + } + err := a.filesImpl.client.Do(ctx, http.MethodPost, apiPath, headers, nil, body, &resp) + if err != nil { + logger.Debugf(ctx, "failed to get abort URL: %v", err) + return + } + if resp.AbortUploadURL.URL == "" { + logger.Debugf(ctx, "no abort URL returned") + return + } + + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, resp.AbortUploadURL.URL, nil) + if err != nil { + logger.Debugf(ctx, "failed to create abort request: %v", err) + return + } + for _, h := range resp.AbortUploadURL.Headers { + req.Header.Set(h.Name, h.Value) + } + + httpClient := &http.Client{Timeout: 30 * time.Second} + abortResp, err := httpClient.Do(req) + if err != nil { + logger.Debugf(ctx, "failed to abort multipart upload: %v", err) + return + } + abortResp.Body.Close() + logger.Debugf(ctx, "aborted multipart upload (status %d)", abortResp.StatusCode) +} + +// UploadWithChunking uploads a file to the given path, automatically choosing +// between single-shot and multipart upload based on the content length. +// For files smaller than 50 MiB, a single PUT request is used. For larger files, +// a multipart upload is performed with configurable part size and parallelism. 
+func (a *FilesAPI) UploadWithChunking(ctx context.Context, filePath string, content io.ReadSeeker, contentLength int64, opts ...UploadOption) error { + ctx = useragent.InContext(ctx, "sdk-feature", "multipart-upload") + + cfg := &UploadConfig{ + Parallelism: defaultParallelism, + } + for _, opt := range opts { + opt(cfg) + } + + // Auto-select part size if not explicitly set + cfg.PartSize, _ = optimizePartSize(contentLength, cfg.PartSize) + + if contentLength < minMultipartUploadSize { + return a.Upload(ctx, UploadRequest{ + FilePath: filePath, + Contents: io.NopCloser(content), + Overwrite: cfg.Overwrite, + }) + } + + err := a.uploadMultipart(ctx, filePath, content, contentLength, cfg) + var fallback *errFallbackToSingleShot + if errors.As(err, &fallback) { + logger.Debugf(ctx, "multipart first-chunk failed (%v), falling back to single-shot upload", fallback.reason) + if _, seekErr := content.Seek(0, io.SeekStart); seekErr != nil { + return fmt.Errorf("failed to rewind content for single-shot fallback: %w", seekErr) + } + return a.Upload(ctx, UploadRequest{ + FilePath: filePath, + Contents: io.NopCloser(content), + Overwrite: cfg.Overwrite, + }) + } + return err +} + +// UploadFromFile uploads a local file to the given path, automatically choosing +// between single-shot and multipart upload based on the file size. +func (a *FilesAPI) UploadFromFile(ctx context.Context, filePath string, sourcePath string, opts ...UploadOption) error { + f, err := os.Open(sourcePath) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer f.Close() + + info, err := f.Stat() + if err != nil { + return fmt.Errorf("failed to stat source file: %w", err) + } + + return a.UploadWithChunking(ctx, filePath, f, info.Size(), opts...) 
+} diff --git a/service/files/ext_upload_test.go b/service/files/ext_upload_test.go new file mode 100644 index 000000000..ccfba3f9a --- /dev/null +++ b/service/files/ext_upload_test.go @@ -0,0 +1,601 @@ +package files + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/databricks/databricks-sdk-go/client" + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOptimizePartSize(t *testing.T) { + const MiB = 1024 * 1024 + const GiB = 1024 * MiB + + tests := []struct { + name string + contentLength int64 + explicitPartSize int64 + wantPartSize int64 + wantBatchSize int + }{ + { + // 5 MiB / 10 MiB = 1 part; ceil(sqrt(1)) = 1 + name: "small file 5 MiB returns defaultPartSize batch 1", + contentLength: 5 * MiB, + wantPartSize: defaultPartSize, + wantBatchSize: 1, + }, + { + // 500 MiB / 10 MiB = 50 parts <= 100; ceil(sqrt(50)) = 8 + name: "500 MiB returns 10 MiB parts batch 8", + contentLength: 500 * MiB, + wantPartSize: 10 * MiB, + wantBatchSize: 8, + }, + { + // 5 GiB: 10 MiB = 512 parts (too many), 20 MiB = 256 (too many), + // 50 MiB = 103 (too many), 100 MiB = 52 <= 100; ceil(sqrt(52)) = 8 + name: "5 GiB returns 100 MiB parts batch 8", + contentLength: 5 * GiB, + wantPartSize: 100 * MiB, + wantBatchSize: 8, + }, + { + // Explicit 20 MiB for 500 MiB file: 500/20 = 25 parts; ceil(sqrt(25)) = 5 + name: "explicit part size 20 MiB for 500 MiB file batch 5", + contentLength: 500 * MiB, + explicitPartSize: 20 * MiB, + wantPartSize: 20 * MiB, + wantBatchSize: 5, + }, + { + // Unknown size (0): returns defaultPartSize, batch 1 + name: "unknown size returns defaultPartSize batch 1", + contentLength: 0, + wantPartSize: defaultPartSize, + wantBatchSize: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotPartSize, gotBatchSize := 
optimizePartSize(tc.contentLength, tc.explicitPartSize) + if gotPartSize != tc.wantPartSize { + t.Errorf("optimizePartSize(%d, %d) partSize = %d, want %d", + tc.contentLength, tc.explicitPartSize, gotPartSize, tc.wantPartSize) + } + if gotBatchSize != tc.wantBatchSize { + t.Errorf("optimizePartSize(%d, %d) batchSize = %d, want %d", + tc.contentLength, tc.explicitPartSize, gotBatchSize, tc.wantBatchSize) + } + }) + } +} + +// multipartMockServer simulates the multipart upload protocol for testing. +type multipartMockServer struct { + mu sync.Mutex + parts map[int][]byte + completed bool + completeParts []completeUploadPart + sessionToken string +} + +func newMultipartMockServer() *multipartMockServer { + return &multipartMockServer{ + parts: make(map[int][]byte), + sessionToken: "test-session-token-123", + } +} + +func (m *multipartMockServer) handler(srv *httptest.Server) http.Handler { + mux := http.NewServeMux() + + // Initiate upload / complete upload / single-shot PUT + mux.HandleFunc("/api/2.0/fs/files/", func(w http.ResponseWriter, r *http.Request) { + // Single-shot PUT upload (no action query parameter) + if r.Method == http.MethodPut { + data, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + m.mu.Lock() + m.parts[1] = data + m.completed = true + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + return + } + action := r.URL.Query().Get("action") + switch action { + case "initiate-upload": + w.Header().Set("Content-Type", "application/json") + resp := initiateUploadResponse{ + MultipartUpload: &multipartUploadSession{ + SessionToken: m.sessionToken, + }, + } + json.NewEncoder(w).Encode(resp) + case "complete-upload": + var req completeUploadRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + m.mu.Lock() + m.completed = true + m.completeParts = req.Parts + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + 
default:
			http.Error(w, "unknown action", http.StatusBadRequest)
		}
	})

	// Create upload part URLs
	mux.HandleFunc("/api/2.0/fs/create-upload-part-urls", func(w http.ResponseWriter, r *http.Request) {
		var req createUploadPartURLsRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		urls := make([]presignedURL, req.Count)
		for i := 0; i < req.Count; i++ {
			pn := req.StartPartNumber + i
			urls[i] = presignedURL{
				URL:        fmt.Sprintf("%s/upload-part/%d", srv.URL, pn),
				PartNumber: pn,
				Headers: []presignedHeader{
					{Name: "x-test-header", Value: "test-value"},
				},
			}
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(createUploadPartURLsResponse{UploadPartURLs: urls})
	})

	// Upload part (presigned URL target)
	mux.HandleFunc("/upload-part/", func(w http.ResponseWriter, r *http.Request) {
		parts := strings.Split(strings.TrimPrefix(r.URL.Path, "/upload-part/"), "/")
		partNum, err := strconv.Atoi(parts[0])
		if err != nil {
			http.Error(w, "invalid part number", http.StatusBadRequest)
			return
		}
		data, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		m.mu.Lock()
		m.parts[partNum] = data
		m.mu.Unlock()

		etag := fmt.Sprintf("etag-part-%d", partNum)
		w.Header().Set("ETag", etag)
		w.WriteHeader(http.StatusOK)
	})

	return mux
}

// newTestFilesAPI builds a FilesAPI whose client talks to serverURL using a
// fixed test token. It fails the test immediately if the config cannot be
// resolved or the client cannot be created.
func newTestFilesAPI(t *testing.T, serverURL string) *FilesAPI {
	t.Helper()
	cfg := &config.Config{
		Host:  serverURL,
		Token: "test-token",
	}
	err := cfg.EnsureResolved()
	require.NoError(t, err)
	databricksClient, err := client.New(cfg)
	require.NoError(t, err)
	return NewFiles(databricksClient)
}

// presignedURLForStartPartHandler answers create-upload-part-urls requests with
// exactly one presigned URL pointing at srv's /upload-part/<StartPartNumber>.
// srv.URL is read on every request, so the handler can be built before it is
// installed on the server. A body that cannot be decoded yields 400.
func presignedURLForStartPartHandler(srv *httptest.Server) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var req createUploadPartURLsRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := createUploadPartURLsResponse{
			UploadPartURLs: []presignedURL{{
				URL:        fmt.Sprintf("%s/upload-part/%d", srv.URL, req.StartPartNumber),
				PartNumber: req.StartPartNumber,
			}},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}
}

// accessDeniedHandler rejects every request with 403 and an "AccessDenied" body.
func accessDeniedHandler(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusForbidden)
	w.Write([]byte("AccessDenied"))
}

// emptyAbortURLHandler answers create-abort-upload-url with an empty abort URL.
func emptyAbortURLHandler(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	w.Write([]byte(`{"abort_upload_url":{"url":""}}`))
}

// TestUploadMultipart_FullFlow uploads 100 KiB in 30 KiB parts against the mock
// server and checks completion, the byte total, the part count and part order.
func TestUploadMultipart_FullFlow(t *testing.T) {
	mock := newMultipartMockServer()

	// The mock handler mints presigned URLs from srv.URL, so the server is
	// started with a placeholder handler and the real one is installed after.
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()
	srv.Config.Handler = mock.handler(srv)

	api := newTestFilesAPI(t, srv.URL)

	// Create 100 KiB of test content
	contentSize := 100 * 1024
	content := make([]byte, contentSize)
	for i := range content {
		content[i] = byte(i % 256)
	}

	cfg := &UploadConfig{
		PartSize:    30 * 1024, // 30 KiB parts
		Parallelism: 2,
		Overwrite:   true,
	}

	ctx := context.Background()
	err := api.uploadMultipart(ctx, "/test/upload.bin", strings.NewReader(string(content)), int64(contentSize), cfg)
	require.NoError(t, err)

	// Verify completion
	mock.mu.Lock()
	defer mock.mu.Unlock()

	assert.True(t, mock.completed, "upload should be completed")
	assert.Greater(t, len(mock.parts), 1, "should have multiple parts, got %d", len(mock.parts))

	// Verify total bytes match
	totalBytes := 0
	for _, data := range mock.parts {
		totalBytes += len(data)
	}
	assert.Equal(t, contentSize, totalBytes, "total uploaded bytes should match content size")

	// 100 KiB / 30 KiB = 4 parts (30+30+30+10)
	expectedParts := 4
	assert.Len(t, mock.completeParts, expectedParts, "complete request should have correct number of parts")

	// Verify parts are sorted by part number
	for i := 1; i < len(mock.completeParts); i++ {
		assert.True(t, mock.completeParts[i].PartNumber > mock.completeParts[i-1].PartNumber,
			"parts should be sorted by part number")
	}
}

// TestUploadWithChunking_SmallFile_UsesSingleShot checks that a tiny payload is
// sent as a single PUT to /api/2.0/fs/files/ and arrives unchanged.
func TestUploadWithChunking_SmallFile_UsesSingleShot(t *testing.T) {
	var receivedBody []byte
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPut && strings.HasPrefix(r.URL.Path, "/api/2.0/fs/files/") {
			body, err := io.ReadAll(r.Body)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			receivedBody = body
			w.WriteHeader(http.StatusOK)
			return
		}
		http.Error(w, "not found", http.StatusNotFound)
	}))
	defer srv.Close()

	api := newTestFilesAPI(t, srv.URL)

	content := []byte("hello multipart!!")
	reader := strings.NewReader(string(content))

	ctx := context.Background()
	err := api.UploadWithChunking(ctx, "/test/small.txt", reader, int64(len(content)))
	require.NoError(t, err)
	assert.Equal(t, content, receivedBody)
}

// TestUploadOnePartWithRetry_RefreshesExpiredURL makes the first part upload
// fail with 403 "AccessDenied" and expects the second attempt to succeed.
func TestUploadOnePartWithRetry_RefreshesExpiredURL(t *testing.T) {
	var uploadAttempts atomic.Int32

	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()

	mux := http.NewServeMux()

	// create-upload-part-urls — returns a URL pointing to /upload-part/<start part>
	mux.HandleFunc("/api/2.0/fs/create-upload-part-urls", presignedURLForStartPartHandler(srv))

	// upload-part — first attempt returns expired URL error, second succeeds
	mux.HandleFunc("/upload-part/", func(w http.ResponseWriter, r *http.Request) {
		if uploadAttempts.Add(1) == 1 {
			accessDeniedHandler(w, r)
			return
		}
		w.Header().Set("ETag", "etag-success")
		w.WriteHeader(http.StatusOK)
	})

	srv.Config.Handler = mux
	api := newTestFilesAPI(t, srv.URL)

	data := strings.NewReader("test data")
	ctx := context.Background()
	etag, err := api.uploadOnePartWithRetry(ctx, "/test/file.bin", "session-tok", 1, data, int64(len("test data")))
	require.NoError(t, err)
	assert.Equal(t, "etag-success", etag)
	assert.Equal(t, int32(2), uploadAttempts.Load())
}

// TestUploadFromFile writes 100 KiB to a temporary file and uploads it by path,
// then checks that the mock server received the content as a single part.
func TestUploadFromFile(t *testing.T) {
	mock := newMultipartMockServer()

	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()
	srv.Config.Handler = mock.handler(srv)

	api := newTestFilesAPI(t, srv.URL)

	// Create a temp file with 100 KiB content
	contentSize := 100 * 1024
	content := make([]byte, contentSize)
	for i := range content {
		content[i] = byte(i % 256)
	}

	// t.TempDir is removed automatically when the test finishes, so the file
	// needs no explicit cleanup.
	tmpFile, err := os.CreateTemp(t.TempDir(), "upload-test-*")
	require.NoError(t, err)

	// Close before asserting so the handle is released even if the write failed.
	_, writeErr := tmpFile.Write(content)
	closeErr := tmpFile.Close()
	require.NoError(t, writeErr)
	require.NoError(t, closeErr)

	ctx := context.Background()
	err = api.UploadFromFile(ctx, "/test/fromfile.bin", tmpFile.Name(),
		WithPartSize(25*1024),
		WithParallelism(2),
		WithOverwrite(true),
	)
	require.NoError(t, err)

	mock.mu.Lock()
	defer mock.mu.Unlock()

	// File is 100 KiB < 50 MiB threshold, so single-shot upload is used
	assert.True(t, mock.completed, "upload should be completed")
	assert.Len(t, mock.parts, 1, "single-shot upload stores as one part")
	assert.Equal(t, contentSize, len(mock.parts[1]), "uploaded bytes should match content size")
	assert.Equal(t, content, mock.parts[1], "uploaded content should match")
}

// TestUploadMultipart_FirstChunkFallback simulates a blocked presigned URL:
// initiate-upload succeeds, but every part upload returns 403 (e.g. an Azure
// firewall). uploadMultipart should return errFallbackToSingleShot so the
// caller can retry with a single-shot upload.
func TestUploadMultipart_FirstChunkFallback(t *testing.T) {
	mux := http.NewServeMux()
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()

	// initiate-upload — succeed with a session token.
	mux.HandleFunc("/api/2.0/fs/files/", func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Query().Get("action") != "initiate-upload" {
			http.Error(w, "unexpected", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(initiateUploadResponse{
			MultipartUpload: &multipartUploadSession{SessionToken: "tok"},
		})
	})

	// create-upload-part-urls — return a URL that will fail.
	mux.HandleFunc("/api/2.0/fs/create-upload-part-urls", presignedURLForStartPartHandler(srv))

	// upload-part — always return 403 to simulate firewall block.
	mux.HandleFunc("/upload-part/", accessDeniedHandler)

	// abort upload URL (best-effort cleanup).
	mux.HandleFunc("/api/2.0/fs/create-abort-upload-url", emptyAbortURLHandler)

	srv.Config.Handler = mux
	api := newTestFilesAPI(t, srv.URL)

	content := strings.NewReader("test data for fallback")
	cfg := &UploadConfig{PartSize: 1024, Parallelism: 1, Overwrite: true}

	err := api.uploadMultipart(context.Background(), "/test/fallback.bin", content, int64(len("test data for fallback")), cfg)

	var fallback *errFallbackToSingleShot
	require.ErrorAs(t, err, &fallback, "should return errFallbackToSingleShot")
	assert.Contains(t, fallback.reason.Error(), "AccessDenied")
}

// TestUploadWithChunking_FallbackToSingleShot is the end-to-end variant:
// UploadWithChunking takes the multipart path, the first chunk is rejected with
// 403, and the call falls back to a single-shot PUT which succeeds.
func TestUploadWithChunking_FallbackToSingleShot(t *testing.T) {
	var receivedBody []byte
	var singleShotCalled atomic.Int32
	var initiateCount atomic.Int32

	mux := http.NewServeMux()
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()

	mux.HandleFunc("/api/2.0/fs/files/", func(w http.ResponseWriter, r *http.Request) {
		action := r.URL.Query().Get("action")
		switch {
		case r.Method == http.MethodPut && action == "":
			// Single-shot PUT fallback.
			singleShotCalled.Add(1)
			body, err := io.ReadAll(r.Body)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			receivedBody = body
			w.WriteHeader(http.StatusOK)
		case action == "initiate-upload":
			initiateCount.Add(1)
			w.Header().Set("Content-Type", "application/json")
			json.NewEncoder(w).Encode(initiateUploadResponse{
				MultipartUpload: &multipartUploadSession{SessionToken: "tok"},
			})
		default:
			http.Error(w, "unexpected", http.StatusBadRequest)
		}
	})

	mux.HandleFunc("/api/2.0/fs/create-upload-part-urls", presignedURLForStartPartHandler(srv))

	// First chunk always fails.
	mux.HandleFunc("/upload-part/", accessDeniedHandler)

	mux.HandleFunc("/api/2.0/fs/create-abort-upload-url", emptyAbortURLHandler)

	srv.Config.Handler = mux
	api := newTestFilesAPI(t, srv.URL)

	// A real payload above minMultipartUploadSize is too large for a unit test,
	// so the declared length is set just above the threshold while the actual
	// data stays tiny. That is enough to route the call to the multipart path.
	contentData := "hello fallback data"
	content := strings.NewReader(contentData)
	fakeLength := int64(minMultipartUploadSize + 1)

	err := api.UploadWithChunking(context.Background(), "/test/fallback.bin", content, fakeLength, WithOverwrite(true))
	require.NoError(t, err)
	assert.Equal(t, int32(1), singleShotCalled.Load(), "single-shot upload should have been called as fallback")
	assert.Equal(t, int32(1), initiateCount.Load(), "multipart initiate should have been called once")
	assert.Equal(t, []byte(contentData), receivedBody, "single-shot should have received the full content")
}

// TestUploadMultipart_ContentLengthMismatch declares far more bytes than the
// reader holds and expects a "content length mismatch" error.
func TestUploadMultipart_ContentLengthMismatch(t *testing.T) {
	mock := newMultipartMockServer()
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()
	srv.Config.Handler = mock.handler(srv)

	api := newTestFilesAPI(t, srv.URL)

	content := strings.NewReader("short")
	// Declare a much larger contentLength than actual data.
	cfg := &UploadConfig{PartSize: 1024, Parallelism: 1, Overwrite: true}

	err := api.uploadMultipart(context.Background(), "/test/mismatch.bin", content, 999999, cfg)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "content length mismatch")
}

// TestUploadMultipart_RejectsNonPositiveContentLength checks that a zero or
// negative contentLength is rejected with "contentLength must be positive".
func TestUploadMultipart_RejectsNonPositiveContentLength(t *testing.T) {
	api := &FilesAPI{}
	cfg := &UploadConfig{PartSize: 1024, Parallelism: 1}

	for _, length := range []int64{0, -1} {
		err := api.uploadMultipart(context.Background(), "/test/file.bin", strings.NewReader("data"), length, cfg)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "contentLength must be positive")
	}
}

// TestUploadWithChunking_LargeFileUsesMultipart checks that UploadWithChunking
// dispatches to the multipart path when the declared length exceeds
// minMultipartUploadSize.
//
// A real payload of that size is too large for a unit test, so only 100 KiB is
// supplied while a length just above the threshold is declared. The upload is
// therefore expected to fail with a content length mismatch. Parts recorded by
// the mock server show that the multipart path was taken, not a single PUT.
func TestUploadWithChunking_LargeFileUsesMultipart(t *testing.T) {
	mock := newMultipartMockServer()
	srv := httptest.NewServer(http.NotFoundHandler())
	defer srv.Close()
	srv.Config.Handler = mock.handler(srv)

	api := newTestFilesAPI(t, srv.URL)

	contentSize := 100 * 1024 // 100 KiB actual data
	content := make([]byte, contentSize)
	for i := range content {
		content[i] = byte(i % 256)
	}

	ctx := context.Background()
	err := api.UploadWithChunking(ctx, "/test/large.bin",
		strings.NewReader(string(content)),
		int64(minMultipartUploadSize+1), // Declare as "large" to trigger multipart
		WithPartSize(30*1024),
		WithParallelism(2),
		WithOverwrite(true),
	)
	// Actual data is shorter than declared, so the upload must fail.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "content length mismatch")

	mock.mu.Lock()
	defer mock.mu.Unlock()
	// Parts reached the presigned upload endpoint, so multipart was used.
	assert.NotEmpty(t, mock.parts, "multipart path should have been taken")
}