diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index 8a62fd6f50..8be7020112 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -14,12 +14,12 @@ import ( "github.com/sirupsen/logrus" "github.com/urfave/cli/v3" - "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/executor" - linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/consts" + "github.com/dstackai/dstack/runner/internal/common/log" "github.com/dstackai/dstack/runner/internal/runner/api" - "github.com/dstackai/dstack/runner/internal/ssh" + "github.com/dstackai/dstack/runner/internal/runner/executor" + linuxuser "github.com/dstackai/dstack/runner/internal/runner/linux/user" + "github.com/dstackai/dstack/runner/internal/runner/ssh" ) // Version is a build-time variable. The value is overridden by ldflags. diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 644d7e80e8..c696bd4673 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -15,9 +15,9 @@ import ( "github.com/sirupsen/logrus" "github.com/urfave/cli/v3" - "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/consts" + "github.com/dstackai/dstack/runner/internal/common/gpu" + "github.com/dstackai/dstack/runner/internal/common/log" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/api" "github.com/dstackai/dstack/runner/internal/shim/components" @@ -236,7 +236,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) var dcgmExporter *dcgm.DCGMExporter var dcgmWrapper dcgm.DCGMWrapperInterface - if common.GetGpuVendor() == common.GpuVendorNvidia { + if gpu.GetGpuVendor() == gpu.GpuVendorNvidia { dcgmExporterPath, err := dcgm.GetDCGMExporterExecPath(ctx) if err == nil { interval := time.Duration(args.DCGMExporter.Interval * int(time.Millisecond)) diff --git a/runner/internal/api/common.go b/runner/internal/common/api/api.go similarity index 98% rename from runner/internal/api/common.go rename to runner/internal/common/api/api.go index 52fa886a0f..85cab57164 100644 --- a/runner/internal/api/common.go +++ b/runner/internal/common/api/api.go @@ -10,7 +10,7 @@ import ( "github.com/golang/gddo/httputil/header" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) type Error struct { diff --git a/runner/consts/consts.go b/runner/internal/common/consts/consts.go similarity index 100% rename from runner/consts/consts.go rename to runner/internal/common/consts/consts.go diff --git a/runner/internal/common/gpu.go b/runner/internal/common/gpu/gpu.go similarity index 98% rename from runner/internal/common/gpu.go rename to runner/internal/common/gpu/gpu.go index 045cc773be..72ae83bb56 100644 --- a/runner/internal/common/gpu.go +++ b/runner/internal/common/gpu/gpu.go @@ -1,4 +1,4 @@ -package common +package gpu import ( "errors" diff --git a/runner/internal/common/interpolator.go b/runner/internal/common/interpolator.go deleted file mode 100644 index 84597df7fa..0000000000 --- a/runner/internal/common/interpolator.go +++ /dev/null @@ -1,67 +0,0 @@ -package common - -import ( - "context" - "fmt" - "strings" - - "github.com/dstackai/dstack/runner/internal/log" -) - -const ( - PatternOpening = "${{" - PatternClosing = "}}" -) - -type VariablesInterpolator struct { - Variables map[string]string -} - -func (vi *VariablesInterpolator) Add(namespace string, vars map[string]string) { - if vi.Variables == nil { - vi.Variables = make(map[string]string, len(vars)) - } - for k, v := range vars { - vi.Variables[fmt.Sprintf("%s.%s", namespace, k)] = v - } -} - -func (vi *VariablesInterpolator) Interpolate(ctx context.Context, s string) (string, error) { - log.Trace(ctx, "Interpolating", "s", s) - var sb strings.Builder - - start := 0 - for start < len(s) { - dollar := IndexWithOffset(s, "$", start) - if dollar == -1 || dollar == len(s)-1 { - sb.WriteString(s[start:]) - break - } - if s[dollar+1] == '$' { // $$ = escaped $ - sb.WriteString(s[start : dollar+1]) - start = dollar + 2 - continue - } - - opening := IndexWithOffset(s, PatternOpening, start) - if opening == -1 { - sb.WriteString(s[start:]) - break - } - sb.WriteString(s[start:opening]) - closing := IndexWithOffset(s, PatternClosing, opening) - if closing == -1 { - return "", fmt.Errorf("no pattern closing: %s", s[opening:]) - } - - name := strings.TrimSpace(s[opening+len(PatternOpening) : closing]) - value, ok := vi.Variables[name] - if ok { - sb.WriteString(value) - } else { - log.Warning(ctx, "Variable is missing", "name", name) - } - start = closing + len(PatternClosing) - } - return sb.String(), nil -} diff --git a/runner/internal/common/interpolator_test.go b/runner/internal/common/interpolator_test.go deleted file mode 100644 index e14a248744..0000000000 --- a/runner/internal/common/interpolator_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package common - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPlainText(t *testing.T) { - var vi VariablesInterpolator - s := "plain text" - result, err := vi.Interpolate(context.Background(), s) - assert.Equal(t, nil, err) - assert.Equal(t, s, result) -} - -func TestMissingVariable(t *testing.T) { - var vi VariablesInterpolator - result, err := vi.Interpolate(context.Background(), "${{ VAR_NAME }} is here") - assert.Equal(t, nil, err) - assert.Equal(t, " is here", result) -} - -func TestDollarEscape(t *testing.T) { - var vi VariablesInterpolator - result, err := vi.Interpolate(context.Background(), "it is not a variable $$!") - assert.Equal(t, nil, err) - assert.Equal(t, "it is not a variable $!", result) -} - -func TestDollarWithoutEscape(t *testing.T) { - var vi VariablesInterpolator - result, err := vi.Interpolate(context.Background(), "it is not a variable $!") - assert.Equal(t, nil, err) - assert.Equal(t, "it is not a variable $!", result) -} - -func TestEscapeOpening(t *testing.T) { - var vi VariablesInterpolator - result, err := vi.Interpolate(context.Background(), "$${{ VAR_NAME }}") - assert.Equal(t, nil, err) - assert.Equal(t, "${{ VAR_NAME }}", result) -} - -func TestWithoutClosing(t *testing.T) { - var vi VariablesInterpolator - _, err := vi.Interpolate(context.Background(), "the end ${{") - assert.NotEqual(t, nil, err) -} - -func TestUnexpectedEOL(t *testing.T) { - var vi VariablesInterpolator - _, err := vi.Interpolate(context.Background(), "the end ${{ VAR }") - assert.NotEqual(t, nil, err) -} - -func TestSecrets(t *testing.T) { - var vi VariablesInterpolator - vi.Add("secrets", map[string]string{"user": "qwerty"}) - result, err := vi.Interpolate(context.Background(), "${{ secrets.user }}") - assert.Equal(t, nil, err) - assert.Equal(t, "qwerty", result) -} diff --git a/runner/internal/log/log.go b/runner/internal/common/log/log.go similarity index 100% rename from runner/internal/log/log.go rename to runner/internal/common/log/log.go diff --git a/runner/internal/common/string.go b/runner/internal/common/string.go deleted file mode 100644 index 28a5ae0756..0000000000 --- a/runner/internal/common/string.go +++ /dev/null @@ -1,11 +0,0 @@ -package common - -import "strings" - -func IndexWithOffset(hay string, needle string, start int) int { - idx := strings.Index(hay[start:], needle) - if idx < 0 { - return -1 - } - return start + idx -} diff --git a/runner/internal/types/types.go b/runner/internal/common/types/types.go similarity index 72% rename from runner/internal/types/types.go rename to runner/internal/common/types/types.go index e8a9519eb4..b7f6c6fd3a 100644 --- a/runner/internal/types/types.go +++ b/runner/internal/common/types/types.go @@ -11,13 +11,3 @@ const ( TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server" TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded" ) - -type JobState string - -const ( - JobStateDone JobState = "done" - JobStateFailed JobState = "failed" - JobStateRunning JobState = "running" - JobStateTerminated JobState = "terminated" - JobStateTerminating JobState = "terminating" -) diff --git a/runner/internal/common/utils.go b/runner/internal/common/utils/utils.go similarity index 95% rename from runner/internal/common/utils.go rename to runner/internal/common/utils/utils.go index 5be68edf70..5bfc17d867 100644 --- a/runner/internal/common/utils.go +++ b/runner/internal/common/utils/utils.go @@ -1,4 +1,4 @@ -package common +package utils import ( "context" @@ -7,7 +7,7 @@ import ( "path" "slices" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) func PathExists(pth string) (bool, error) { diff --git a/runner/internal/common/utils_test.go b/runner/internal/common/utils/utils_test.go similarity index 99% rename from runner/internal/common/utils_test.go rename to runner/internal/common/utils/utils_test.go index 5fe780d503..f38ac57925 100644 --- a/runner/internal/common/utils_test.go +++ b/runner/internal/common/utils/utils_test.go @@ -1,4 +1,4 @@ -package common +package utils import ( "context" diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 4d1c7daf54..34220acc6e 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -11,10 +11,10 @@ import ( "net/http" "strconv" - "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/executor" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/common/api" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/runner/executor" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) // TODO: set some reasonable value; (optional) make configurable diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index ba577d1a5b..11b76d887e 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -7,10 +7,10 @@ import ( _ "net/http/pprof" "time" - "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/executor" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/metrics" + "github.com/dstackai/dstack/runner/internal/common/api" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/runner/executor" + "github.com/dstackai/dstack/runner/internal/runner/metrics" ) type Server struct { diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index bc6e476c0e..3229701a68 100644 --- a/runner/internal/runner/api/ws.go +++ b/runner/internal/runner/api/ws.go @@ -8,7 +8,7 @@ import ( "github.com/gorilla/websocket" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) type logsWsRequestParams struct { diff --git a/runner/internal/connections/connections.go b/runner/internal/runner/connections/connections.go similarity index 98% rename from runner/internal/connections/connections.go rename to runner/internal/runner/connections/connections.go index 37aedad7a2..4a56a6f172 100644 --- a/runner/internal/connections/connections.go +++ b/runner/internal/runner/connections/connections.go @@ -8,7 +8,7 @@ import ( "github.com/prometheus/procfs" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) const connStateEstablished = 1 diff --git a/runner/internal/connections/connections_test.go b/runner/internal/runner/connections/connections_test.go similarity index 100% rename from runner/internal/connections/connections_test.go rename to runner/internal/runner/connections/connections_test.go diff --git a/runner/internal/executor/base.go b/runner/internal/runner/executor/base.go similarity index 75% rename from runner/internal/executor/base.go rename to runner/internal/runner/executor/base.go index fac1266fb0..b8093e5e72 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/runner/executor/base.go @@ -4,8 +4,8 @@ import ( "context" "io" - "github.com/dstackai/dstack/runner/internal/schemas" - "github.com/dstackai/dstack/runner/internal/types" + "github.com/dstackai/dstack/runner/internal/common/types" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) type Executor interface { @@ -15,10 +15,10 @@ type Executor interface { GetJobInfo(ctx context.Context) (username string, workingDir string, err error) Run(ctx context.Context) error SetJob(job schemas.SubmitBody) - SetJobState(ctx context.Context, state types.JobState) + SetJobState(ctx context.Context, state schemas.JobState) SetJobStateWithTerminationReason( ctx context.Context, - state types.JobState, + state schemas.JobState, terminationReason types.TerminationReason, terminationMessage string, ) diff --git a/runner/internal/executor/env.go b/runner/internal/runner/executor/env.go similarity index 100% rename from runner/internal/executor/env.go rename to runner/internal/runner/executor/env.go diff --git a/runner/internal/executor/env_test.go b/runner/internal/runner/executor/env_test.go similarity index 100% rename from runner/internal/executor/env_test.go rename to runner/internal/runner/executor/env_test.go diff --git a/runner/internal/executor/executor.go b/runner/internal/runner/executor/executor.go similarity index 93% rename from runner/internal/executor/executor.go rename to runner/internal/runner/executor/executor.go index 311eddaa10..98289eb4ec 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/runner/executor/executor.go @@ -24,15 +24,15 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/sys/unix" - "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/connections" - cap "github.com/dstackai/dstack/runner/internal/linux/capabilities" - linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/schemas" - "github.com/dstackai/dstack/runner/internal/ssh" - "github.com/dstackai/dstack/runner/internal/types" + "github.com/dstackai/dstack/runner/internal/common/consts" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/types" + "github.com/dstackai/dstack/runner/internal/common/utils" + "github.com/dstackai/dstack/runner/internal/runner/connections" + cap "github.com/dstackai/dstack/runner/internal/runner/linux/capabilities" + linuxuser "github.com/dstackai/dstack/runner/internal/runner/linux/user" + "github.com/dstackai/dstack/runner/internal/runner/schemas" + "github.com/dstackai/dstack/runner/internal/runner/ssh" ) // TODO: Tune these parameters for optimal experience/performance @@ -164,7 +164,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { jobLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerJobLogFileName)) if err != nil { - ex.SetJobState(ctx, types.JobStateFailed) + ex.SetJobState(ctx, schemas.JobStateFailed) return fmt.Errorf("create job log file: %w", err) } defer func() { _ = jobLogFile.Close() }() @@ -173,7 +173,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { // recover goes after postRun(), which closes runnerLogFile, to keep the log if r := recover(); r != nil { log.Error(ctx, "Executor PANIC", "err", r) - ex.SetJobState(ctx, types.JobStateFailed) + ex.SetJobState(ctx, schemas.JobStateFailed) err = fmt.Errorf("recovered: %v", r) } // no more logs will be written after this @@ -211,7 +211,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { if err := ex.setupRepo(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, - types.JobStateFailed, + schemas.JobStateFailed, types.TerminationReasonContainerExitedWithError, fmt.Sprintf("Failed to set up the repo (%s)", err), ) @@ -221,7 +221,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { if err := ex.setupFiles(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, - types.JobStateFailed, + schemas.JobStateFailed, types.TerminationReasonExecutorError, fmt.Sprintf("Failed to set up files (%s)", err), ) @@ -232,7 +232,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { go ex.connectionTracker.Track(connectionTrackerTicker.C) defer ex.connectionTracker.Stop() - ex.SetJobState(ctx, types.JobStateRunning) + ex.SetJobState(ctx, schemas.JobStateRunning) timeoutCtx := ctx var cancelTimeout context.CancelFunc if ex.jobSpec.MaxDuration != 0 { @@ -243,7 +243,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { select { case <-ctx.Done(): log.Error(ctx, "Job canceled") - ex.SetJobState(ctx, types.JobStateTerminated) + ex.SetJobState(ctx, schemas.JobStateTerminated) return fmt.Errorf("job canceled: %w", err) default: } @@ -253,7 +253,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { log.Error(ctx, "Max duration exceeded", "max_duration", ex.jobSpec.MaxDuration) ex.SetJobStateWithTerminationReason( ctx, - types.JobStateTerminated, + schemas.JobStateTerminated, types.TerminationReasonMaxDurationExceeded, "Max duration exceeded", ) @@ -265,14 +265,14 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { log.Error(ctx, "Exec failed", "err", err) var exitError *exec.ExitError if errors.As(err, &exitError) { - ex.SetJobStateWithExitStatus(ctx, types.JobStateFailed, exitError.ExitCode()) + ex.SetJobStateWithExitStatus(ctx, schemas.JobStateFailed, exitError.ExitCode()) } else { - ex.SetJobState(ctx, types.JobStateFailed) + ex.SetJobState(ctx, schemas.JobStateFailed) } return fmt.Errorf("exec job failed: %w", err) } - ex.SetJobStateWithExitStatus(ctx, types.JobStateDone, 0) + ex.SetJobStateWithExitStatus(ctx, schemas.JobStateDone, 0) return nil } @@ -286,12 +286,12 @@ func (ex *RunExecutor) SetJob(body schemas.SubmitBody) { ex.state = WaitCode } -func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) { +func (ex *RunExecutor) SetJobState(ctx context.Context, state schemas.JobState) { ex.SetJobStateWithTerminationReason(ctx, state, "", "") } func (ex *RunExecutor) SetJobStateWithTerminationReason( - ctx context.Context, state types.JobState, terminationReason types.TerminationReason, terminationMessage string, + ctx context.Context, state schemas.JobState, terminationReason types.TerminationReason, terminationMessage string, ) { ex.mu.Lock() ex.jobStateHistory = append( @@ -311,7 +311,7 @@ func (ex *RunExecutor) SetJobStateWithTerminationReason( } func (ex *RunExecutor) SetJobStateWithExitStatus( - ctx context.Context, state types.JobState, exitStatus int, + ctx context.Context, state schemas.JobState, exitStatus int, ) { ex.mu.Lock() ex.jobStateHistory = append( @@ -343,7 +343,7 @@ func (ex *RunExecutor) preRun(ctx context.Context) error { // logging is required for the subsequent setJob{User,WorkingDir} calls runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName)) if err != nil { - ex.SetJobState(ctx, types.JobStateFailed) + ex.SetJobState(ctx, schemas.JobStateFailed) return fmt.Errorf("create runner log file: %w", err) } ex.runnerLogFile = runnerLogFile @@ -358,7 +358,7 @@ func (ex *RunExecutor) preRun(ctx context.Context) error { if err := ex.setJobUser(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, - types.JobStateFailed, + schemas.JobStateFailed, types.TerminationReasonExecutorError, fmt.Sprintf("Failed to set job user (%s)", err), ) @@ -367,7 +367,7 @@ func (ex *RunExecutor) preRun(ctx context.Context) error { if err := ex.setJobWorkingDir(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, - types.JobStateFailed, + schemas.JobStateFailed, types.TerminationReasonExecutorError, fmt.Sprintf("Failed to set job working dir (%s)", err), ) @@ -399,7 +399,7 @@ func (ex *RunExecutor) setJobWorkingDir(ctx context.Context) error { return fmt.Errorf("get working directory: %w", err) } } else { - ex.jobWorkingDir, err = common.ExpandPath(*ex.jobSpec.WorkingDir, "", ex.jobUser.HomeDir) + ex.jobWorkingDir, err = utils.ExpandPath(*ex.jobSpec.WorkingDir, "", ex.jobUser.HomeDir) if err != nil { return fmt.Errorf("expand working dir path: %w", err) } @@ -508,7 +508,7 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error } cmd.WaitDelay = ex.killDelay // kills the process if it doesn't exit in time - if err := common.MkdirAll(ctx, ex.jobWorkingDir, ex.jobUser.Uid, ex.jobUser.Gid, 0o755); err != nil { + if err := utils.MkdirAll(ctx, ex.jobWorkingDir, ex.jobUser.Uid, ex.jobUser.Gid, 0o755); err != nil { return fmt.Errorf("create working directory: %w", err) } cmd.Dir = ex.jobWorkingDir @@ -636,7 +636,7 @@ func (ex *RunExecutor) setupGitCredentials(ctx context.Context) (func(), error) if _, err := os.Stat(hostsPath); err == nil { return nil, fmt.Errorf("hosts.yml file already exists") } - if err := common.MkdirAll(ctx, filepath.Dir(hostsPath), ex.jobUser.Uid, ex.jobUser.Gid, 0o700); err != nil { + if err := utils.MkdirAll(ctx, filepath.Dir(hostsPath), ex.jobUser.Uid, ex.jobUser.Gid, 0o700); err != nil { return nil, fmt.Errorf("create gh config directory: %w", err) } log.Info(ctx, "Writing OAuth token", "path", hostsPath) diff --git a/runner/internal/executor/executor_test.go b/runner/internal/runner/executor/executor_test.go similarity index 98% rename from runner/internal/executor/executor_test.go rename to runner/internal/runner/executor/executor_test.go index 105493e301..915cca35a6 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/runner/executor/executor_test.go @@ -14,8 +14,8 @@ import ( "testing" "time" - linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" - "github.com/dstackai/dstack/runner/internal/schemas" + linuxuser "github.com/dstackai/dstack/runner/internal/runner/linux/user" + "github.com/dstackai/dstack/runner/internal/runner/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/runner/internal/executor/files.go b/runner/internal/runner/executor/files.go similarity index 91% rename from runner/internal/executor/files.go rename to runner/internal/runner/executor/files.go index 6b992ce2c1..f61b9e2429 100644 --- a/runner/internal/executor/files.go +++ b/runner/internal/runner/executor/files.go @@ -12,8 +12,8 @@ import ( "github.com/codeclysm/extract/v4" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/utils" ) var renameRegex = regexp.MustCompile(`^([^/]*)(/|$)`) @@ -62,12 +62,12 @@ func (ex *RunExecutor) setupFiles(ctx context.Context) error { func extractFileArchive(ctx context.Context, archivePath string, destPath string, baseDir string, homeDir string, uid int, gid int) error { log.Trace(ctx, "Extracting file archive", "archive", archivePath, "dest", destPath, "base", baseDir, "home", homeDir) - destPath, err := common.ExpandPath(destPath, baseDir, homeDir) + destPath, err := utils.ExpandPath(destPath, baseDir, homeDir) if err != nil { return fmt.Errorf("expand destination path: %w", err) } destBase, destName := path.Split(destPath) - if err := common.MkdirAll(ctx, destBase, uid, gid, 0o755); err != nil { + if err := utils.MkdirAll(ctx, destBase, uid, gid, 0o755); err != nil { return fmt.Errorf("create destination directory: %w", err) } if err := os.RemoveAll(destPath); err != nil { diff --git a/runner/internal/executor/lock.go b/runner/internal/runner/executor/lock.go similarity index 100% rename from runner/internal/executor/lock.go rename to runner/internal/runner/executor/lock.go diff --git a/runner/internal/executor/logs.go b/runner/internal/runner/executor/logs.go similarity index 91% rename from runner/internal/executor/logs.go rename to runner/internal/runner/executor/logs.go index 807071eeb9..808fc84b1d 100644 --- a/runner/internal/executor/logs.go +++ b/runner/internal/runner/executor/logs.go @@ -3,7 +3,7 @@ package executor import ( "sync" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) type appendWriter struct { diff --git a/runner/internal/executor/query.go b/runner/internal/runner/executor/query.go similarity index 94% rename from runner/internal/executor/query.go rename to runner/internal/runner/executor/query.go index 6678e5f8d7..f3acbf20ac 100644 --- a/runner/internal/executor/query.go +++ b/runner/internal/runner/executor/query.go @@ -1,7 +1,7 @@ package executor import ( - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) func (ex *RunExecutor) GetJobWsLogsHistory() []schemas.LogEvent { diff --git a/runner/internal/executor/repo.go b/runner/internal/runner/executor/repo.go similarity index 95% rename from runner/internal/executor/repo.go rename to runner/internal/runner/executor/repo.go index dd16092be9..116e4b225d 100644 --- a/runner/internal/executor/repo.go +++ b/runner/internal/runner/executor/repo.go @@ -13,10 +13,10 @@ import ( "github.com/codeclysm/extract/v4" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/repo" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/utils" + "github.com/dstackai/dstack/runner/internal/runner/repo" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) // WriteRepoBlob must be called after SetJob @@ -50,7 +50,7 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error { } var err error - ex.repoDir, err = common.ExpandPath(*ex.jobSpec.RepoDir, ex.jobWorkingDir, ex.jobUser.HomeDir) + ex.repoDir, err = utils.ExpandPath(*ex.jobSpec.RepoDir, ex.jobWorkingDir, ex.jobUser.HomeDir) if err != nil { return fmt.Errorf("expand repo dir path: %w", err) } @@ -236,7 +236,7 @@ func (ex *RunExecutor) restoreRepoDir(ctx context.Context, tmpDir string) error func (ex *RunExecutor) chownRepoDir(ctx context.Context) error { log.Trace(ctx, "Chowning repo dir") - exists, err := common.PathExists(ex.repoDir) + exists, err := utils.PathExists(ex.repoDir) // We consider all errors here non-fatal if err != nil { log.Warning(ctx, "Failed to check if repo dir exists", "err", err) diff --git a/runner/internal/executor/states.go b/runner/internal/runner/executor/states.go similarity index 100% rename from runner/internal/executor/states.go rename to runner/internal/runner/executor/states.go diff --git a/runner/internal/executor/timestamp.go b/runner/internal/runner/executor/timestamp.go similarity index 95% rename from runner/internal/executor/timestamp.go rename to runner/internal/runner/executor/timestamp.go index b1cf0fa2cc..b06d8cf47e 100644 --- a/runner/internal/executor/timestamp.go +++ b/runner/internal/runner/executor/timestamp.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) type MonotonicTimestamp struct { diff --git a/runner/internal/executor/timestamp_test.go b/runner/internal/runner/executor/timestamp_test.go similarity index 100% rename from runner/internal/executor/timestamp_test.go rename to runner/internal/runner/executor/timestamp_test.go diff --git a/runner/internal/executor/user.go b/runner/internal/runner/executor/user.go similarity index 96% rename from runner/internal/executor/user.go rename to runner/internal/runner/executor/user.go index 30affda617..df9f0fe45c 100644 --- a/runner/internal/executor/user.go +++ b/runner/internal/runner/executor/user.go @@ -10,9 +10,9 @@ import ( "strconv" "strings" - linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/common/log" + linuxuser "github.com/dstackai/dstack/runner/internal/runner/linux/user" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) func (ex *RunExecutor) setJobUser(ctx context.Context) error { diff --git a/runner/internal/executor/user_test.go b/runner/internal/runner/executor/user_test.go similarity index 98% rename from runner/internal/executor/user_test.go rename to runner/internal/runner/executor/user_test.go index 2bc6a19d87..c0fc202f2e 100644 --- a/runner/internal/executor/user_test.go +++ b/runner/internal/runner/executor/user_test.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/require" - linuxuser "github.com/dstackai/dstack/runner/internal/linux/user" - "github.com/dstackai/dstack/runner/internal/schemas" + linuxuser "github.com/dstackai/dstack/runner/internal/runner/linux/user" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) var shouldNotBeCalledErr = errors.New("this function should not be called") diff --git a/runner/internal/linux/capabilities/capabilities_darwin.go b/runner/internal/runner/linux/capabilities/capabilities_darwin.go similarity index 100% rename from runner/internal/linux/capabilities/capabilities_darwin.go rename to runner/internal/runner/linux/capabilities/capabilities_darwin.go diff --git a/runner/internal/linux/capabilities/capabilities_linux.go b/runner/internal/runner/linux/capabilities/capabilities_linux.go similarity index 100% rename from runner/internal/linux/capabilities/capabilities_linux.go rename to runner/internal/runner/linux/capabilities/capabilities_linux.go diff --git a/runner/internal/linux/user/user.go b/runner/internal/runner/linux/user/user.go similarity index 100% rename from runner/internal/linux/user/user.go rename to runner/internal/runner/linux/user/user.go diff --git a/runner/internal/metrics/cgroups.go b/runner/internal/runner/metrics/cgroups.go similarity index 97% rename from runner/internal/metrics/cgroups.go rename to runner/internal/runner/metrics/cgroups.go index 9ce1e54fe6..7ac89db4a1 100644 --- a/runner/internal/metrics/cgroups.go +++ b/runner/internal/runner/metrics/cgroups.go @@ -8,7 +8,7 @@ import ( "os" "strings" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) func getProcessCgroupMountPoint(ctx context.Context, ProcPidMountsPath string) (string, error) { diff --git a/runner/internal/metrics/cgroups_test.go b/runner/internal/runner/metrics/cgroups_test.go similarity index 100% rename from runner/internal/metrics/cgroups_test.go rename to runner/internal/runner/metrics/cgroups_test.go diff --git a/runner/internal/metrics/metrics.go b/runner/internal/runner/metrics/metrics.go similarity index 95% rename from runner/internal/metrics/metrics.go rename to runner/internal/runner/metrics/metrics.go index 26acc2cdf4..56c27a2bb1 100644 --- a/runner/internal/metrics/metrics.go +++ b/runner/internal/runner/metrics/metrics.go @@ -12,14 +12,14 @@ import ( "strings" "time" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/common/gpu" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/runner/schemas" ) type MetricsCollector struct { cgroupMountPoint string - gpuVendor common.GpuVendor + gpuVendor gpu.GpuVendor } func NewMetricsCollector(ctx context.Context) (*MetricsCollector, error) { @@ -29,7 +29,7 @@ func NewMetricsCollector(ctx context.Context) (*MetricsCollector, error) { if err != nil { return nil, fmt.Errorf("get cgroup mount point: %w", err) } - gpuVendor := common.GetGpuVendor() + gpuVendor := gpu.GetGpuVendor() return &MetricsCollector{ cgroupMountPoint: cgroupMountPoint, gpuVendor: gpuVendor, @@ -141,15 +141,15 @@ func (s *MetricsCollector) GetGPUMetrics(ctx context.Context) ([]schemas.GPUMetr var metrics []schemas.GPUMetrics var err error switch s.gpuVendor { - case common.GpuVendorNvidia: + case gpu.GpuVendorNvidia: metrics, err = s.GetNVIDIAGPUMetrics(ctx) - case common.GpuVendorAmd: + case gpu.GpuVendorAmd: metrics, err = s.GetAMDGPUMetrics(ctx) - case common.GpuVendorIntel: + case gpu.GpuVendorIntel: metrics, err = s.GetIntelAcceleratorMetrics(ctx) - case common.GpuVendorTenstorrent: + case gpu.GpuVendorTenstorrent: err = errors.New("tenstorrent metrics not suppored") - case common.GpuVendorNone: + case gpu.GpuVendorNone: // pass } if metrics == nil { diff --git a/runner/internal/metrics/metrics_test.go b/runner/internal/runner/metrics/metrics_test.go similarity index 95% rename from runner/internal/metrics/metrics_test.go rename to runner/internal/runner/metrics/metrics_test.go index 152f31c1b7..3410435ce2 100644 --- a/runner/internal/metrics/metrics_test.go +++ b/runner/internal/runner/metrics/metrics_test.go @@ -4,7 +4,7 @@ import ( "runtime" "testing" - "github.com/dstackai/dstack/runner/internal/schemas" + "github.com/dstackai/dstack/runner/internal/runner/schemas" "github.com/stretchr/testify/assert" ) diff --git a/runner/internal/repo/diff.go b/runner/internal/runner/repo/diff.go similarity index 98% rename from runner/internal/repo/diff.go rename to runner/internal/runner/repo/diff.go index 43e6b2e20f..a7f33cad6c 100644 --- a/runner/internal/repo/diff.go +++ b/runner/internal/runner/repo/diff.go @@ -12,7 +12,7 @@ import ( "github.com/bluekeyes/go-gitdiff/gitdiff" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) func ApplyDiff(ctx context.Context, dir, patch string) error { diff --git a/runner/internal/repo/diff_test.go b/runner/internal/runner/repo/diff_test.go similarity index 100% rename from runner/internal/repo/diff_test.go rename to runner/internal/runner/repo/diff_test.go diff --git a/runner/internal/repo/manager.go b/runner/internal/runner/repo/manager.go similarity index 98% rename from runner/internal/repo/manager.go rename to runner/internal/runner/repo/manager.go index 6e546a0886..baeec40fad 100644 --- a/runner/internal/repo/manager.go +++ b/runner/internal/runner/repo/manager.go @@ -10,7 +10,7 @@ import ( gitssh "github.com/go-git/go-git/v5/plumbing/transport/ssh" "golang.org/x/crypto/ssh" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) type Manager struct { diff --git a/runner/internal/schemas/schemas.go b/runner/internal/runner/schemas/schemas.go similarity index 92% rename from runner/internal/schemas/schemas.go rename to runner/internal/runner/schemas/schemas.go index 10ab62ea95..ca707db761 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/runner/schemas/schemas.go @@ -3,11 +3,21 @@ package schemas import ( "strings" - "github.com/dstackai/dstack/runner/internal/types" + "github.com/dstackai/dstack/runner/internal/common/types" +) + +type JobState string + +const ( + JobStateDone JobState = "done" + JobStateFailed JobState = "failed" + JobStateRunning JobState = "running" + JobStateTerminated JobState = "terminated" + JobStateTerminating JobState = "terminating" ) type JobStateEvent struct { - State types.JobState `json:"state"` + State JobState `json:"state"` Timestamp int64 `json:"timestamp"` TerminationReason types.TerminationReason `json:"termination_reason"` TerminationMessage string `json:"termination_message"` diff --git a/runner/internal/ssh/sshd.go b/runner/internal/runner/ssh/sshd.go similarity index 96% rename from runner/internal/ssh/sshd.go rename to runner/internal/runner/ssh/sshd.go index d46be7e24f..05da8d1401 100644 --- a/runner/internal/ssh/sshd.go +++ b/runner/internal/runner/ssh/sshd.go @@ -11,8 +11,8 @@ import ( "syscall" "time" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/utils" ) type SshdManager interface { @@ -203,7 +203,7 @@ func copyHostKey(srcDir string, destDir string, key string) error { func prepareAuthorizedKeysFile(confDir string) (string, error) { // Ensures that the file exists, has correct ownership and permissions, and is empty akPath := path.Join(confDir, "authorized_keys") - if _, err := common.RemoveIfExists(akPath); err != nil { + if _, err := utils.RemoveIfExists(akPath); err != nil { return "", err } file, err := os.OpenFile(akPath, os.O_CREATE|os.O_EXCL|os.O_RDONLY, 0o644) @@ -268,7 +268,7 @@ func prepareLogPath(logDir string) (string, error) { return "", err } logPath := path.Join(logDir, "sshd.log") - if _, err := common.RemoveIfExists(logPath); err != nil { + if _, err := utils.RemoveIfExists(logPath); err != nil { return "", err } return logPath, nil diff --git a/runner/internal/shim/api/handlers.go b/runner/internal/shim/api/handlers.go index dc1be824cb..b3382d0f26 100644 --- a/runner/internal/shim/api/handlers.go +++ b/runner/internal/shim/api/handlers.go @@ -5,8 +5,8 @@ import ( "errors" "net/http" - "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/api" + "github.com/dstackai/dstack/runner/internal/common/log" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" diff --git a/runner/internal/shim/api/handlers_test.go b/runner/internal/shim/api/handlers_test.go index 9bc829a94c..bb19ebbf1b 100644 --- a/runner/internal/shim/api/handlers_test.go +++ b/runner/internal/shim/api/handlers_test.go @@ -6,7 +6,7 @@ import ( "strings" "testing" - common "github.com/dstackai/dstack/runner/internal/api" + commonapi "github.com/dstackai/dstack/runner/internal/common/api" ) func TestHealthcheck(t *testing.T) { @@ -15,7 +15,7 @@ func TestHealthcheck(t *testing.T) { server := NewShimServer(context.Background(), ":12345", "0.0.1.dev2", NewDummyRunner(), nil, nil, nil, nil) - f := common.JSONResponseHandler(server.HealthcheckHandler) + f := commonapi.JSONResponseHandler(server.HealthcheckHandler) f(responseRecorder, request) if responseRecorder.Code != 200 { @@ -39,7 +39,7 @@ func TestTaskSubmit(t *testing.T) { request := httptest.NewRequest("POST", "/api/tasks", strings.NewReader(requestBody)) responseRecorder := httptest.NewRecorder() - firstSubmitPost := common.JSONResponseHandler(server.TaskSubmitHandler) + firstSubmitPost := commonapi.JSONResponseHandler(server.TaskSubmitHandler) firstSubmitPost(responseRecorder, request) if responseRecorder.Code != 200 { t.Errorf("Want status '%d', got '%d'", 200, responseRecorder.Code) @@ -47,7 +47,7 @@ func TestTaskSubmit(t *testing.T) { request = httptest.NewRequest("POST", "/api/tasks", strings.NewReader(requestBody)) responseRecorder = httptest.NewRecorder() - secondSubmitPost := common.JSONResponseHandler(server.TaskSubmitHandler) + secondSubmitPost := commonapi.JSONResponseHandler(server.TaskSubmitHandler) secondSubmitPost(responseRecorder, request) if responseRecorder.Code != 409 { t.Errorf("Want status '%d', got '%d'", 409, responseRecorder.Code) diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 0482db7945..9008aa2efe 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -8,8 +8,8 @@ import ( "reflect" "sync" - "github.com/dstackai/dstack/runner/internal/api" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/api" + "github.com/dstackai/dstack/runner/internal/common/log" "github.com/dstackai/dstack/runner/internal/shim" "github.com/dstackai/dstack/runner/internal/shim/components" "github.com/dstackai/dstack/runner/internal/shim/dcgm" diff --git a/runner/internal/shim/components/utils.go b/runner/internal/shim/components/utils.go index 073832133d..a4456acaa3 100644 --- a/runner/internal/shim/components/utils.go +++ b/runner/internal/shim/components/utils.go @@ -12,8 +12,8 @@ import ( "strings" "time" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/utils" ) const downloadTimeout = 10 * time.Minute @@ -90,7 +90,7 @@ func downloadFile(ctx context.Context, url string, path string, mode os.FileMode } func checkDstackComponent(ctx context.Context, name ComponentName, pth string) (status ComponentStatus, version string, err error) { - exists, err := common.PathExists(pth) + exists, err := utils.PathExists(pth) if err != nil { return ComponentStatusError, "", fmt.Errorf("check %s: %w", name, err) } diff --git a/runner/internal/shim/dcgm/exporter.go b/runner/internal/shim/dcgm/exporter.go index f49fb91aee..ed861eb524 100644 --- a/runner/internal/shim/dcgm/exporter.go +++ b/runner/internal/shim/dcgm/exporter.go @@ -17,7 +17,7 @@ import ( "github.com/alexellis/go-execute/v2" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) // Counter represents a single line in counters.csv, see diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 88a7f37c02..6acfb27a51 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -31,12 +31,12 @@ import ( "github.com/docker/go-units" bytesize "github.com/inhies/go-bytesize" - "github.com/dstackai/dstack/runner/consts" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/consts" + "github.com/dstackai/dstack/runner/internal/common/gpu" + "github.com/dstackai/dstack/runner/internal/common/log" + "github.com/dstackai/dstack/runner/internal/common/types" "github.com/dstackai/dstack/runner/internal/shim/backends" "github.com/dstackai/dstack/runner/internal/shim/host" - "github.com/dstackai/dstack/runner/internal/types" ) // TODO: Allow for configuration via cli arguments or environment variables. @@ -55,7 +55,7 @@ type DockerRunner struct { dockerParams DockerParameters dockerInfo dockersystem.Info gpus []host.GpuInfo - gpuVendor common.GpuVendor + gpuVendor gpu.GpuVendor gpuLock *GpuLock tasks TaskStorage } @@ -70,12 +70,12 @@ func NewDockerRunner(ctx context.Context, dockerParams DockerParameters) (*Docke return nil, fmt.Errorf("get docker info: %w", err) } - var gpuVendor common.GpuVendor + var gpuVendor gpu.GpuVendor gpus := host.GetGpuInfo(ctx) if len(gpus) > 0 { gpuVendor = gpus[0].Vendor } else { - gpuVendor = common.GpuVendorNone + gpuVendor = gpu.GpuVendorNone } gpuLock, err := NewGpuLock(gpus) if err != nil { @@ -135,7 +135,7 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error { log.Error(ctx, "failed to inspect container", "id", containerID, "task", taskID) } else { switch d.gpuVendor { - case common.GpuVendorNvidia: + case gpu.GpuVendorNvidia: deviceRequests := containerFull.HostConfig.DeviceRequests if len(deviceRequests) == 1 { gpuIDs = deviceRequests[0].DeviceIDs @@ -146,13 +146,13 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error { "id", containerID, "task", taskID, ) } - case common.GpuVendorAmd: + case gpu.GpuVendorAmd: for _, device := range containerFull.HostConfig.Devices { if host.IsRenderNodePath(device.PathOnHost) { gpuIDs = append(gpuIDs, device.PathOnHost) } } - case common.GpuVendorTenstorrent: + case gpu.GpuVendorTenstorrent: for _, device := range containerFull.HostConfig.Devices { if strings.HasPrefix(device.PathOnHost, "/dev/tenstorrent/") { // Extract the device ID from the path @@ -160,14 +160,14 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error { gpuIDs = append(gpuIDs, deviceID) } } - case common.GpuVendorIntel: + case gpu.GpuVendorIntel: for _, envVar := range containerFull.Config.Env { if indices, found := strings.CutPrefix(envVar, "HABANA_VISIBLE_DEVICES="); found { gpuIDs = strings.Split(indices, ",") break } } - case common.GpuVendorNone: + case gpu.GpuVendorNone: gpuIDs = []string{} } ports = extractPorts(ctx, containerFull.NetworkSettings.Ports) @@ -1024,12 +1024,12 @@ func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevic } } -func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor common.GpuVendor, ids []string) { +func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor gpu.GpuVendor, ids []string) { // NVIDIA: ids are identifiers reported by nvidia-smi, GPU- strings // AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128 // Tenstorrent: ids are device indices to be used with /dev/tenstorrent/ switch vendor { - case common.GpuVendorNvidia: + case gpu.GpuVendorNvidia: hostConfig.DeviceRequests = append( hostConfig.DeviceRequests, container.DeviceRequest{ @@ -1040,7 +1040,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v DeviceIDs: ids, }, ) - case common.GpuVendorAmd: + case gpu.GpuVendorAmd: // All options are listed here: https://hub.docker.com/r/rocm/pytorch // Only --device are mandatory, other seem to be performance-related. // --device=/dev/kfd @@ -1070,7 +1070,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v // --security-opt=seccomp=unconfined hostConfig.SecurityOpt = append(hostConfig.SecurityOpt, "seccomp=unconfined") // TODO: in addition, for non-root user, --group-add=video, and possibly --group-add=render, are required. - case common.GpuVendorTenstorrent: + case gpu.GpuVendorTenstorrent: // For Tenstorrent, simply add each device for _, id := range ids { devicePath := fmt.Sprintf("/dev/tenstorrent/%s", id) @@ -1091,7 +1091,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v Target: "/dev/hugepages-1G", }) } - case common.GpuVendorIntel: + case gpu.GpuVendorIntel: // All options are listed here: // https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html // --runtime=habana @@ -1102,7 +1102,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v hostConfig.CapAdd = append(hostConfig.CapAdd, "SYS_NICE") // -e HABANA_VISIBLE_DEVICES=0,1,... config.Env = append(config.Env, fmt.Sprintf("HABANA_VISIBLE_DEVICES=%s", strings.Join(ids, ","))) - case common.GpuVendorNone: + case gpu.GpuVendorNone: // nothing to do } } diff --git a/runner/internal/shim/host/gpu.go b/runner/internal/shim/host/gpu.go index b2b2135efc..0452f1ff46 100644 --- a/runner/internal/shim/host/gpu.go +++ b/runner/internal/shim/host/gpu.go @@ -13,8 +13,8 @@ import ( execute "github.com/alexellis/go-execute/v2" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/gpu" + "github.com/dstackai/dstack/runner/internal/common/log" ) const ( @@ -23,7 +23,7 @@ const ( ) type GpuInfo struct { - Vendor common.GpuVendor + Vendor gpu.GpuVendor Name string Vram int // MiB // NVIDIA: uuid field from nvidia-smi, "globally unique immutable alphanumeric identifier of the GPU", @@ -43,16 +43,16 @@ type GpuInfo struct { } func GetGpuInfo(ctx context.Context) []GpuInfo { - switch gpuVendor := common.GetGpuVendor(); gpuVendor { - case common.GpuVendorNvidia: + switch gpuVendor := gpu.GetGpuVendor(); gpuVendor { + case gpu.GpuVendorNvidia: return getNvidiaGpuInfo(ctx) - case common.GpuVendorAmd: + case gpu.GpuVendorAmd: return getAmdGpuInfo(ctx) - case common.GpuVendorIntel: + case gpu.GpuVendorIntel: return getIntelGpuInfo(ctx) - case common.GpuVendorTenstorrent: + case gpu.GpuVendorTenstorrent: return getTenstorrentGpuInfo(ctx) - case common.GpuVendorNone: + case gpu.GpuVendorNone: return []GpuInfo{} } return []GpuInfo{} @@ -99,7 +99,7 @@ func getNvidiaGpuInfo(ctx context.Context) []GpuInfo { vram = 0 } gpus = append(gpus, GpuInfo{ - Vendor: common.GpuVendorNvidia, + Vendor: gpu.GpuVendorNvidia, Name: strings.TrimSpace(record[0]), Vram: vram, ID: strings.TrimSpace(record[2]), @@ -170,7 +170,7 @@ func getAmdGpuInfo(ctx context.Context) []GpuInfo { continue } gpus = append(gpus, GpuInfo{ - Vendor: common.GpuVendorAmd, + Vendor: gpu.GpuVendorAmd, Name: amdGpu.Asic.Name, Vram: amdGpu.Vram.Size.Value, RenderNodePath: renderNodePath, @@ -233,7 +233,7 @@ func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo { // Create new GPU entry for "L" device lDeviceMap[uniqueID] = &GpuInfo{ - Vendor: common.GpuVendorTenstorrent, + Vendor: gpu.GpuVendorTenstorrent, Name: name, Vram: baseVram, ID: boardID, @@ -304,7 +304,7 @@ func getGpusFromTtSmiSnapshot(snapshot *ttSmiSnapshot) []GpuInfo { if !existingGpu { // Create new GPU entry lDeviceMap[uniqueID] = &GpuInfo{ - Vendor: common.GpuVendorTenstorrent, + Vendor: gpu.GpuVendorTenstorrent, Name: boardType, Vram: baseVram, ID: boardID, @@ -423,7 +423,7 @@ func getIntelGpuInfo(ctx context.Context) []GpuInfo { vram = 0 } gpus = append(gpus, GpuInfo{ - Vendor: common.GpuVendorIntel, + Vendor: gpu.GpuVendorIntel, Name: strings.TrimSpace(record[0]), Vram: vram, Index: strings.TrimSpace(record[2]), diff --git a/runner/internal/shim/host/gpu_test.go b/runner/internal/shim/host/gpu_test.go index 2f8eda8e2e..9facf9992a 100644 --- a/runner/internal/shim/host/gpu_test.go +++ b/runner/internal/shim/host/gpu_test.go @@ -7,7 +7,7 @@ import ( "strconv" "testing" - "github.com/dstackai/dstack/runner/internal/common" + "github.com/dstackai/dstack/runner/internal/common/gpu" ) func loadTestData(filename string) ([]byte, error) { @@ -172,7 +172,7 @@ func TestGetGpusFromTtSmiSnapshot(t *testing.T) { expectedGpus := []GpuInfo{ { - Vendor: common.GpuVendorTenstorrent, + Vendor: gpu.GpuVendorTenstorrent, Name: "n150", Vram: 12 * 1024, ID: "100018611902010", @@ -222,19 +222,19 @@ func TestGetGpusFromTtSmiSnapshotMultipleDevices(t *testing.T) { } for boardID, expected := range expectedGpus { - gpu, exists := gpusByID[boardID] + gpu_, exists := gpusByID[boardID] if !exists { t.Errorf("Expected GPU with board_id %s not found", boardID) continue } - if gpu.Name != expected.name { - t.Errorf("GPU %s: name = %s, want %s", boardID, gpu.Name, expected.name) + if gpu_.Name != expected.name { + t.Errorf("GPU %s: name = %s, want %s", boardID, gpu_.Name, expected.name) } - if gpu.Vram != expected.vram { - t.Errorf("GPU %s: VRAM = %d, want %d", boardID, gpu.Vram, expected.vram) + if gpu_.Vram != expected.vram { + t.Errorf("GPU %s: VRAM = %d, want %d", boardID, gpu_.Vram, expected.vram) } - if gpu.Vendor != common.GpuVendorTenstorrent { - t.Errorf("GPU %s: vendor = %v, want %v", boardID, gpu.Vendor, common.GpuVendorTenstorrent) + if gpu_.Vendor != gpu.GpuVendorTenstorrent { + t.Errorf("GPU %s: vendor = %v, want %v", boardID, gpu_.Vendor, gpu.GpuVendorTenstorrent) } } } @@ -263,25 +263,25 @@ func TestGetGpusFromTtSmiSnapshotGalaxy(t *testing.T) { actualTotalVram := 0 // Verify all GPUs have the correct properties - for i, gpu := range gpus { - if gpu.Vendor != common.GpuVendorTenstorrent { - t.Errorf("GPU[%d] vendor = %v, want %v", i, gpu.Vendor, common.GpuVendorTenstorrent) + for i, gpu_ := range gpus { + if gpu_.Vendor != gpu.GpuVendorTenstorrent { + t.Errorf("GPU[%d] vendor = %v, want %v", i, gpu_.Vendor, gpu.GpuVendorTenstorrent) } - if gpu.Name != "tt-galaxy-wh" { - t.Errorf("GPU[%d] name = %s, want tt-galaxy-wh", i, gpu.Name) + if gpu_.Name != "tt-galaxy-wh" { + t.Errorf("GPU[%d] name = %s, want tt-galaxy-wh", i, gpu_.Name) } - if gpu.ID != "100035100000000" { - t.Errorf("GPU[%d] ID = %s, want 100035100000000", i, gpu.ID) + if gpu_.ID != "100035100000000" { + t.Errorf("GPU[%d] ID = %s, want 100035100000000", i, gpu_.ID) } - if gpu.Vram != 12*1024 { - t.Errorf("GPU[%d] VRAM = %d, want %d", i, gpu.Vram, 12*1024) + if gpu_.Vram != 12*1024 { + t.Errorf("GPU[%d] VRAM = %d, want %d", i, gpu_.Vram, 12*1024) } // Verify indices are sequential (0, 1, 2, ..., 31) expectedIndex := strconv.Itoa(i) - if gpu.Index != expectedIndex { - t.Errorf("GPU[%d] index = %s, want %s", i, gpu.Index, expectedIndex) + if gpu_.Index != expectedIndex { + t.Errorf("GPU[%d] index = %s, want %s", i, gpu_.Index, expectedIndex) } - actualTotalVram += gpu.Vram + actualTotalVram += gpu_.Vram } // Verify total VRAM is 384GB diff --git a/runner/internal/shim/host/host.go b/runner/internal/shim/host/host.go index bc54a407c7..84d15d1ae8 100644 --- a/runner/internal/shim/host/host.go +++ b/runner/internal/shim/host/host.go @@ -9,7 +9,7 @@ import ( "github.com/shirou/gopsutil/v4/mem" "golang.org/x/sys/unix" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) func GetCpuCount(ctx context.Context) int { diff --git a/runner/internal/shim/host_info.go b/runner/internal/shim/host_info.go index ea717e112c..2634d939c3 100644 --- a/runner/internal/shim/host_info.go +++ b/runner/internal/shim/host_info.go @@ -7,18 +7,18 @@ import ( "os" "path/filepath" - "github.com/dstackai/dstack/runner/internal/common" + "github.com/dstackai/dstack/runner/internal/common/gpu" ) type hostInfo struct { - GpuVendor common.GpuVendor `json:"gpu_vendor"` - GpuName string `json:"gpu_name"` - GpuMemory int `json:"gpu_memory"` // MiB - GpuCount int `json:"gpu_count"` - Addresses []string `json:"addresses"` - DiskSize uint64 `json:"disk_size"` // bytes - NumCPUs int `json:"cpus"` - Memory uint64 `json:"memory"` // bytes + GpuVendor gpu.GpuVendor `json:"gpu_vendor"` + GpuName string `json:"gpu_name"` + GpuMemory int `json:"gpu_memory"` // MiB + GpuCount int `json:"gpu_count"` + Addresses []string `json:"addresses"` + DiskSize uint64 `json:"disk_size"` // bytes + NumCPUs int `json:"cpus"` + Memory uint64 `json:"memory"` // bytes } func WriteHostInfo(dir string, resources Resources) error { @@ -28,7 +28,7 @@ func WriteHostInfo(dir string, resources Resources) error { return err } - gpuVendor := common.GpuVendorNone + gpuVendor := gpu.GpuVendorNone gpuCount := 0 gpuMemory := 0 gpuName := "" diff --git a/runner/internal/shim/resources.go b/runner/internal/shim/resources.go index bcc589f272..e0d888873b 100644 --- a/runner/internal/shim/resources.go +++ b/runner/internal/shim/resources.go @@ -6,8 +6,8 @@ import ( "fmt" "sync" - "github.com/dstackai/dstack/runner/internal/common" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/gpu" + "github.com/dstackai/dstack/runner/internal/common/log" "github.com/dstackai/dstack/runner/internal/shim/host" ) @@ -33,21 +33,21 @@ func NewGpuLock(gpus []host.GpuInfo) (*GpuLock, error) { lock := make(map[string]bool, len(gpus)) if len(gpus) > 0 { vendor := gpus[0].Vendor - for _, gpu := range gpus { - if gpu.Vendor != vendor { + for _, gpu_ := range gpus { + if gpu_.Vendor != vendor { return nil, errors.New("multiple GPU vendors detected") } var resourceID string switch vendor { - case common.GpuVendorNvidia: - resourceID = gpu.ID - case common.GpuVendorAmd: - resourceID = gpu.RenderNodePath - case common.GpuVendorTenstorrent: - resourceID = gpu.Index - case common.GpuVendorIntel: - resourceID = gpu.Index - case common.GpuVendorNone: + case gpu.GpuVendorNvidia: + resourceID = gpu_.ID + case gpu.GpuVendorAmd: + resourceID = gpu_.RenderNodePath + case gpu.GpuVendorTenstorrent: + resourceID = gpu_.Index + case gpu.GpuVendorIntel: + resourceID = gpu_.Index + case gpu.GpuVendorNone: return nil, fmt.Errorf("unexpected GPU vendor %s", vendor) default: return nil, fmt.Errorf("unexpected GPU vendor %s", vendor) diff --git a/runner/internal/shim/resources_test.go b/runner/internal/shim/resources_test.go index f582d14cf2..424ff55b41 100644 --- a/runner/internal/shim/resources_test.go +++ b/runner/internal/shim/resources_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/dstackai/dstack/runner/internal/common" + "github.com/dstackai/dstack/runner/internal/common/gpu" "github.com/dstackai/dstack/runner/internal/shim/host" "github.com/stretchr/testify/assert" ) @@ -18,8 +18,8 @@ func TestNewGpuLock_NoGpus(t *testing.T) { func TestNewGpuLock_NvidiaGpus(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, } gl, err := NewGpuLock(gpus) assert.Nil(t, err) @@ -32,8 +32,8 @@ func TestNewGpuLock_NvidiaGpus(t *testing.T) { func TestNewGpuLock_AmdGpus(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorAmd, RenderNodePath: "/dev/dri/renderD128"}, - {Vendor: common.GpuVendorAmd, RenderNodePath: "/dev/dri/renderD129"}, + {Vendor: gpu.GpuVendorAmd, RenderNodePath: "/dev/dri/renderD128"}, + {Vendor: gpu.GpuVendorAmd, RenderNodePath: "/dev/dri/renderD129"}, } gl, err := NewGpuLock(gpus) assert.Nil(t, err) @@ -46,8 +46,8 @@ func TestNewGpuLock_AmdGpus(t *testing.T) { func TestNewGpuLock_ErrorMultipleVendors(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorAmd}, - {Vendor: common.GpuVendorNvidia}, + {Vendor: gpu.GpuVendorAmd}, + {Vendor: gpu.GpuVendorNvidia}, } gl, err := NewGpuLock(gpus) assert.Nil(t, gl) @@ -68,9 +68,9 @@ func TestGpuLock_Acquire_ErrorBadCount(t *testing.T) { func TestGpuLock_Acquire_All_Available(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-c0de"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-c0de"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-f00d"] = true @@ -84,8 +84,8 @@ func TestGpuLock_Acquire_All_Available(t *testing.T) { func TestGpuLock_Acquire_All_NoneAvailable(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-beef"] = true @@ -104,10 +104,10 @@ func TestGpuLock_Acquire_All_NoGpus(t *testing.T) { func TestGpuLock_Acquire_Count_OK(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-c0de"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-cafe"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-c0de"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-cafe"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-f00d"] = true @@ -128,8 +128,8 @@ func TestGpuLock_Acquire_Count_OK(t *testing.T) { func TestGpuLock_Acquire_Count_ErrNoCapacity(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-f00d"] = true @@ -142,9 +142,9 @@ func TestGpuLock_Acquire_Count_ErrNoCapacity(t *testing.T) { func TestGpuLock_Lock(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-c0de"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-c0de"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-beef"] = true @@ -162,8 +162,8 @@ func TestGpuLock_Lock(t *testing.T) { func TestGpuLock_Lock_Nil(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-beef"] = true @@ -176,9 +176,9 @@ func TestGpuLock_Lock_Nil(t *testing.T) { func TestGpuLock_Release(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-c0de"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-c0de"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-beef"] = true @@ -196,8 +196,8 @@ func TestGpuLock_Release(t *testing.T) { func TestGpuLock_Release_Nil(t *testing.T) { gpus := []host.GpuInfo{ - {Vendor: common.GpuVendorNvidia, ID: "GPU-beef"}, - {Vendor: common.GpuVendorNvidia, ID: "GPU-f00d"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-beef"}, + {Vendor: gpu.GpuVendorNvidia, ID: "GPU-f00d"}, } gl, _ := NewGpuLock(gpus) gl.lock["GPU-beef"] = true diff --git a/runner/internal/shim/task.go b/runner/internal/shim/task.go index f1d67b785c..d2fef7e02d 100644 --- a/runner/internal/shim/task.go +++ b/runner/internal/shim/task.go @@ -6,7 +6,7 @@ import ( "fmt" "sync" - "github.com/dstackai/dstack/runner/internal/log" + "github.com/dstackai/dstack/runner/internal/common/log" ) type TaskStatus string