Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
184 changes: 102 additions & 82 deletions cmd/root_cmd/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,112 @@ func runPush(cmdCtx context.Context, in *pushInputs) error {
return err
}

codeResp, codePushed, err := s.pushCodeAndWaitForParsing()
versionId, codeSnapshotCid, codePushed, err := s.runPushFlow()
if err != nil {
return err
}

if err := s.applyModelToVersion(codeResp.VersionId); err != nil {
return err
s.sendSuccessEvent(versionId, codeSnapshotCid, codePushed)

return s.triggerEvaluate(versionId, dispatch)
}

// runPushFlow pushes the version and returns its id and code-snapshot id.
// A new model goes through the combined push job (code parse + model import in
// one). Overwriting a version that already has a model goes through the
// override-push job (re-parse code + re-validate the existing model, no upload).
func (s *pushState) runPushFlow() (versionId, codeSnapshotCid string, codePushed bool, err error) {
if s.needsNewModel() {
return s.runCombinedPush()
}
return s.runOverridePush()
}

func (s *pushState) runOverridePush() (versionId, codeSnapshotCid string, codePushed bool, err error) {
in := s.inputs

closeBundle, tarGzFile, err := code.BundleCodeIntoTempFile(".", s.workspaceConfig)
if err != nil {
return "", "", false, s.fail("bundle_code", err)
}
defer closeBundle()

s.sendSuccessEvent(codeResp, codePushed)
fileStat, err := tarGzFile.Stat()
if err != nil {
return "", "", false, s.fail("bundle_code", fmt.Errorf("failed to get file stat: %w", err))
}

return s.triggerEvaluate(codeResp.VersionId, dispatch)
pushResp, err := code.PushOverride(
s.ctx, tarGzFile, fileStat.Size(),
s.workspaceConfig.EntryFile, in.secretId, in.pythonVersion,
s.projectId(), in.branch, s.overwriteVersion.VersionId,
)
if err != nil {
return "", "", false, s.fail("push_override", err)
}

s.properties["code_snapshot_id"] = pushResp.CodeSnapshot.Cid
s.properties["version_id"] = pushResp.VersionId
s.properties["job_id"] = pushResp.JobId

if in.noWait {
log.Info("Starting push job. JobId: ", pushResp.JobId)
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, nil
}

if err := model.WaitForPushJob(s.ctx, s.projectId(), pushResp.VersionId, pushResp.JobId); err != nil {
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, s.fail("push_job", err)
}
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, nil
}

func (s *pushState) runCombinedPush() (versionId, codeSnapshotCid string, codePushed bool, err error) {
in := s.inputs

closeBundle, tarGzFile, err := code.BundleCodeIntoTempFile(".", s.workspaceConfig)
if err != nil {
return "", "", false, s.fail("bundle_code", err)
}
defer closeBundle()

modelInfo, err := model.PrepareImportModelFromFilePath(s.ctx, s.projectId(), in.modelPath, in.transformInput, in.modelType)
if err != nil {
return "", "", false, s.fail("prepare_model", err)
}

fileStat, err := tarGzFile.Stat()
if err != nil {
return "", "", false, s.fail("bundle_code", fmt.Errorf("failed to get file stat: %w", err))
}

overwriteVersionId := ""
if s.overwriteVersion != nil {
overwriteVersionId = s.overwriteVersion.VersionId
}

pushResp, err := code.PushCodeAndModel(
s.ctx, tarGzFile, fileStat.Size(),
s.workspaceConfig.EntryFile, in.secretId, in.pythonVersion,
in.modelVersionName, s.projectId(), in.branch, overwriteVersionId,
*modelInfo,
)
if err != nil {
return "", "", false, s.fail("push", err)
}

s.properties["code_snapshot_id"] = pushResp.CodeSnapshot.Cid
s.properties["version_id"] = pushResp.VersionId
s.properties["job_id"] = pushResp.JobId

if in.noWait {
log.Info("Starting push job. JobId: ", pushResp.JobId)
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, nil
}

if err := model.WaitForPushJob(s.ctx, s.projectId(), pushResp.VersionId, pushResp.JobId); err != nil {
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, s.fail("push_job", err)
}
return pushResp.VersionId, pushResp.CodeSnapshot.Cid, true, nil
}

func validatePushInputs(in *pushInputs) error {
Expand Down Expand Up @@ -374,84 +468,10 @@ func (s *pushState) askOrDefaultBatchSize() (int, error) {
return chosen, nil
}

func (s *pushState) pushCodeAndWaitForParsing() (*tensorleapapi.PushCodeSnapshotResponse, bool, error) {
in := s.inputs
closeBundle, tarGzFile, err := code.BundleCodeIntoTempFile(".", s.workspaceConfig)
if err != nil {
return nil, false, s.fail("bundle_code", err)
}
defer closeBundle()

overwriteVersionId := ""
if s.overwriteVersion != nil {
overwriteVersionId = s.overwriteVersion.VersionId
}
pushed, codeResp, err := code.PushCode(s.ctx, tarGzFile, s.workspaceConfig.EntryFile, in.secretId, in.pythonVersion, in.modelVersionName, s.projectId(), in.branch, overwriteVersionId)
if err != nil {
s.tagCodeResp(codeResp)
return codeResp, false, s.fail("push_code", err)
}

if pushed || !code.IsCodeEnded(&codeResp.CodeSnapshot) {
ok, waitErr := code.WaitForCodeIntegrationStatus(s.ctx, s.projectId(), codeResp.CodeSnapshot.Cid)
if waitErr != nil {
s.tagCodeResp(codeResp)
return codeResp, pushed, s.fail("wait_for_code_parsing", waitErr)
}
if !ok {
s.tagCodeResp(codeResp)
return codeResp, pushed, s.fail("code_parsing", fmt.Errorf("code parsing failed"))
}
log.Info("Code parsed successfully")
} else if code.IsCodeParseFailed(&codeResp.CodeSnapshot) {
s.tagCodeResp(codeResp)
return codeResp, pushed, s.fail("previous_code_parsing_failed", fmt.Errorf("latest code parsing failed, add --force to push anyway"))
}
return codeResp, pushed, nil
}

func (s *pushState) tagCodeResp(codeResp *tensorleapapi.PushCodeSnapshotResponse) {
if codeResp == nil {
return
}
s.properties["code_snapshot_id"] = codeResp.CodeSnapshot.Cid
s.properties["version_id"] = codeResp.VersionId
}

func (s *pushState) applyModelToVersion(versionId string) error {
in := s.inputs
if !s.isOverwrite {
importModelInfo, err := model.PrepareImportModelFromFilePath(s.ctx, s.projectId(), in.modelPath, in.transformInput, in.modelType)
if err != nil {
return err
}
if _, err = model.ImportModel(s.ctx, s.projectId(), versionId, importModelInfo, !in.noWait); err != nil {
s.properties["code_snapshot_id"] = versionId
s.properties["version_id"] = versionId
return s.fail("import_model", err)
}
return nil
}

var importModelInfo *tensorleapapi.ImportModelInfo
if !s.overwriteVersion.HasModel && !s.overwriteVersion.HasUploadedModel {
var err error
importModelInfo, err = model.PrepareImportModelFromFilePath(s.ctx, s.projectId(), in.modelPath, in.transformInput, in.modelType)
if err != nil {
return err
}
}
if _, err := model.OverrideModel(s.ctx, s.projectId(), versionId, !in.noWait, importModelInfo); err != nil {
s.properties["version_id"] = versionId
return s.fail("override_model", err)
}
return nil
}

func (s *pushState) sendSuccessEvent(codeResp *tensorleapapi.PushCodeSnapshotResponse, codePushed bool) {
func (s *pushState) sendSuccessEvent(versionId, codeSnapshotCid string, codePushed bool) {
in := s.inputs
s.properties["code_snapshot_id"] = codeResp.CodeSnapshot.Cid
s.properties["version_id"] = codeResp.VersionId
s.properties["code_snapshot_id"] = codeSnapshotCid
s.properties["version_id"] = versionId
s.properties["project_id"] = s.projectId()
s.properties["final_secret_id"] = in.secretId
s.properties["final_python_version"] = in.pythonVersion
Expand Down
73 changes: 63 additions & 10 deletions pkg/code/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ func GetCodeSnapshot(ctx context.Context, projectId, id string) (*CodeSnapshot,
return &res.CodeSnapshot, err
}

func PushCodeSnapshot(
// PushCodeAndModel uploads the code bundle and triggers the combined push job
// (code parse + model import in one job) via the new push endpoint, replacing
// the separate pushCodeSnapshot + importModel calls.
func PushCodeAndModel(
ctx context.Context, tarGzFile io.Reader, fileSize int64,
entryFile, secretId, pythonVersion,
versionName, projectId, branch string,
overwriteVersionId string,
) (*tensorleapapi.PushCodeSnapshotResponse, error) {
modelInfo tensorleapapi.ImportModelInfo,
) (*tensorleapapi.PushResponse, error) {

uploadUrl, err := GetCodeSnapshotUploadUrl(ctx, projectId)
if err != nil {
Expand All @@ -38,32 +42,81 @@ func PushCodeSnapshot(
return nil, err
}

saveCodeSnapshotParams := *tensorleapapi.NewPushCodeSnapshotParams(
pushParams := *tensorleapapi.NewPushParams(
projectId,
uploadUrl,
entryFile,
versionName,
modelInfo,
)

if len(overwriteVersionId) > 0 {
saveCodeSnapshotParams.SetOverwriteVersionId(overwriteVersionId)
pushParams.SetOverwriteVersionId(overwriteVersionId)
}

if len(pythonVersion) > 0 {
saveCodeSnapshotParams.GenericBaseImageType = &pythonVersion
pushParams.GenericBaseImageType = &pythonVersion
}

if len(branch) > 0 {
saveCodeSnapshotParams.SetBranchName(branch)
pushParams.SetBranchName(branch)
}

if len(secretId) > 0 {
saveCodeSnapshotParams.SecretManagerId = &secretId
pushParams.SecretManagerId = &secretId
}

log.Info("Pushing code snapshot...")
result, response, err := api.ApiClient.PushCodeSnapshot(ctx).
PushCodeSnapshotParams(saveCodeSnapshotParams).
log.Info("Pushing code and model...")
result, response, err := api.ApiClient.Push(ctx).
PushParams(pushParams).
Execute()
if err = api.CheckRes(response, err); err != nil {
return nil, err
}

return result, nil
}

// PushOverride re-pushes code to an existing version and re-validates its
// existing model via the override-push endpoint — no new model upload, no name
// (both come from the overwritten version).
func PushOverride(
ctx context.Context, tarGzFile io.Reader, fileSize int64,
entryFile, secretId, pythonVersion, projectId, branch string,
overwriteVersionId string,
) (*tensorleapapi.PushResponse, error) {

uploadUrl, err := GetCodeSnapshotUploadUrl(ctx, projectId)
if err != nil {
return nil, err
}

if err := api.UploadFile(uploadUrl, tarGzFile, fileSize); err != nil {
return nil, err
}

params := *tensorleapapi.NewPushOverrideParams(
projectId,
uploadUrl,
entryFile,
overwriteVersionId,
)

if len(pythonVersion) > 0 {
params.GenericBaseImageType = &pythonVersion
}

if len(branch) > 0 {
params.SetBranchName(branch)
}

if len(secretId) > 0 {
params.SecretManagerId = &secretId
}

log.Info("Pushing code (override)...")
result, response, err := api.ApiClient.PushOverride(ctx).
PushOverrideParams(params).
Execute()
if err = api.CheckRes(response, err); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/code/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func CollectCodeSnapshotParserErr(civ *CodeSnapshot) *CodeSnapshotParserErrs {
errs.printLog = civ.ParseResult.SetupStatus.GetPrintLog()
}

hasErrors := lo.SomeBy(civ.ParseResult.SetupStatus.BindersStatus, func(binderStatus tensorleapapi.DatasetTestResultPayload) bool {
hasErrors := lo.SomeBy(civ.ParseResult.SetupStatus.BindersStatus, func(binderStatus tensorleapapi.CodeTestResultPayload) bool {
return !binderStatus.IsPassed && len(binderStatus.Display) > 0
})

Expand Down
20 changes: 0 additions & 20 deletions pkg/code/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,26 +553,6 @@ func isExcluded(path string, excludePatterns []string) bool {
return false
}

func PushCode(ctx context.Context, tarGzFile *os.File, entryFile, secretId, pythonVersion, versionName, projectId, branch string, overwriteVersionId string) (pushed bool, current *tensorleapapi.PushCodeSnapshotResponse, err error) {

fileStat, err := tarGzFile.Stat()
if err != nil {
return false, nil, fmt.Errorf("failed to get file stat: %v", err)
}

codeSnapshot, err := PushCodeSnapshot(
ctx, tarGzFile, fileStat.Size(),
entryFile, secretId, pythonVersion, versionName,
projectId,
branch,
overwriteVersionId,
)
if err != nil {
return false, nil, err
}
return true, codeSnapshot, nil
}

func CompareCodeVersion(ctx context.Context, compareVersion *CodeSnapshot, tarGzFile *os.File, entryFile, secretId, pythonVersion string) (bool, error) {

if isCodeSnapshotEmpty(compareVersion) {
Expand Down
Loading
Loading