diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 4eca1aee7b..c05430368a 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -35,6 +35,7 @@ the SSH server and handling the connection proxy. var autoStartCluster bool var userKnownHostsFile string var liteswap string + var skipSettingsCheck bool cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") @@ -64,6 +65,9 @@ the SSH server and handling the connection proxy. cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)") cmd.Flags().MarkHidden("liteswap") + cmd.Flags().BoolVar(&skipSettingsCheck, "skip-settings-check", false, "Skip checking and updating IDE settings") + cmd.Flags().MarkHidden("skip-settings-check") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -113,6 +117,7 @@ the SSH server and handling the connection proxy. ClientPrivateKeyName: clientPrivateKeyName, UserKnownHostsFile: userKnownHostsFile, Liteswap: liteswap, + SkipSettingsCheck: skipSettingsCheck, AdditionalArgs: args, } return client.Run(ctx, wsClient, opts) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 940f792f0e..d751833db1 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -20,6 +20,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" "github.com/databricks/cli/experimental/ssh/internal/sshconfig" + "github.com/databricks/cli/experimental/ssh/internal/vscode" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -92,6 +93,8 @@ type ClientOptions struct { UserKnownHostsFile string // Liteswap header value for traffic routing (dev/test only). Liteswap string + // If true, skip checking and updating IDE settings. + SkipSettingsCheck bool } func (o *ClientOptions) IsServerlessMode() bool { @@ -206,6 +209,26 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt cmdio.LogString(ctx, "Using SSH key: "+keyPath) cmdio.LogString(ctx, fmt.Sprintf("Secrets scope: %s, key name: %s", secretScopeName, opts.ClientPublicKeyName)) + // Check and update IDE settings for serverless mode, where we must set up + // desired server ports (or socket connection mode) for the connection to go through + // (as the majority of the localhost ports on the remote side are blocked by iptable rules). + // Plus the platform (always linux), and extensions (python and jupyter), to make the initial experience smoother. + if opts.IDE != "" && opts.IsServerlessMode() && !opts.ProxyMode && !opts.SkipSettingsCheck && cmdio.IsPromptSupported(ctx) { + err = vscode.CheckAndUpdateSettings(ctx, opts.IDE, opts.ConnectionName) + if err != nil { + cmdio.LogString(ctx, fmt.Sprintf("Failed to update IDE settings: %v", err)) + cmdio.LogString(ctx, vscode.GetManualInstructions(opts.IDE, opts.ConnectionName)) + cmdio.LogString(ctx, "Use --skip-settings-check to bypass IDE settings verification.") + shouldProceed, promptErr := cmdio.AskYesOrNo(ctx, "Do you want to proceed with the connection?") + if promptErr != nil { + return fmt.Errorf("failed to prompt user: %w", promptErr) + } + if !shouldProceed { + return errors.New("aborted: IDE settings need to be updated manually, user declined to proceed") + } + } + } + var userName string var serverPort int var clusterID string diff --git a/experimental/ssh/internal/vscode/settings.go b/experimental/ssh/internal/vscode/settings.go new file mode 100644 index 0000000000..43f3840c8b --- /dev/null +++ b/experimental/ssh/internal/vscode/settings.go @@ -0,0 +1,316 @@ +package vscode + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/log" + "github.com/tidwall/jsonc" +) + +const ( + portRange = "4000-4005" + remotePlatform = "linux" + pythonExtension = "ms-python.python" + jupyterExtension = "ms-toolsai.jupyter" + serverPickPortsKey = "remote.SSH.serverPickPortsFromRange" + remotePlatformKey = "remote.SSH.remotePlatform" + defaultExtensionsKey = "remote.SSH.defaultExtensions" + listenOnSocketKey = "remote.SSH.remoteServerListenOnSocket" + vscodeIDE = "vscode" + cursorIDE = "cursor" + vscodeName = "VS Code" + cursorName = "Cursor" +) + +func getIDEName(ide string) string { + if ide == cursorIDE { + return cursorName + } + return vscodeName +} + +type missingSettings struct { + portRange bool + platform bool + listenOnSocket bool + extensions []string +} + +func (m *missingSettings) isEmpty() bool { + return !m.portRange && !m.platform && !m.listenOnSocket && len(m.extensions) == 0 +} + +func CheckAndUpdateSettings(ctx context.Context, ide, connectionName string) error { + if !cmdio.IsPromptSupported(ctx) { + log.Debugf(ctx, "Skipping IDE settings check: prompts not supported") + return nil + } + + settingsPath, err := getDefaultSettingsPath(ctx, ide) + if err != nil { + return fmt.Errorf("failed to get settings path: %w", err) + } + + settings, err := loadSettings(settingsPath) + if err != nil { + if os.IsNotExist(err) { + return handleMissingFile(ctx, ide, connectionName, settingsPath) + } + return fmt.Errorf("failed to load settings: %w", err) + } + + missing := validateSettings(settings, connectionName) + if missing.isEmpty() { + log.Debugf(ctx, "IDE settings already correct for %s", connectionName) + return nil + } + + shouldUpdate, err := promptUserForUpdate(ctx, ide, connectionName, missing) + if err != nil { + return fmt.Errorf("failed to prompt user: %w", err) + } + if !shouldUpdate { + log.Infof(ctx, "Skipping IDE settings update") + return nil + } + + if err := backupSettings(ctx, settingsPath); err != nil { + log.Warnf(ctx, "Failed to backup settings: %v. Continuing with update.", err) + } + + updateSettings(settings, connectionName, missing) + + if err := saveSettings(settingsPath, settings); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Updated %s settings for '%s'", getIDEName(ide), connectionName)) + return nil +} + +func getDefaultSettingsPath(ctx context.Context, ide string) (string, error) { + home, err := env.UserHomeDir(ctx) + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + + appName := "Code" + if ide == cursorIDE { + appName = "Cursor" + } + + var settingsDir string + switch runtime.GOOS { + case "darwin": + settingsDir = filepath.Join(home, "Library", "Application Support", appName, "User") + case "windows": + appData := env.Get(ctx, "APPDATA") + if appData == "" { + appData = filepath.Join(home, "AppData", "Roaming") + } + settingsDir = filepath.Join(appData, appName, "User") + case "linux": + settingsDir = filepath.Join(home, ".config", appName, "User") + default: + return "", fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } + + return filepath.Join(settingsDir, "settings.json"), nil +} + +func loadSettings(path string) (map[string]any, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + // VS Code/Cursor settings files are in JSONC format (JSON with comments). + cleanJSON := jsonc.ToJSON(data) + var settings map[string]any + if err := json.Unmarshal(cleanJSON, &settings); err != nil { + return nil, fmt.Errorf("failed to parse settings JSON: %w", err) + } + return settings, nil +} + +func hasCorrectPortRange(settings map[string]any, connectionName string) bool { + portRangeObj, ok := settings[serverPickPortsKey].(map[string]any) + if !ok { + return false + } + val, ok := portRangeObj[connectionName].(string) + return ok && val == portRange +} + +func hasCorrectPlatform(settings map[string]any, connectionName string) bool { + platformObj, ok := settings[remotePlatformKey].(map[string]any) + if !ok { + return false + } + val, ok := platformObj[connectionName].(string) + return ok && val == remotePlatform +} + +func hasCorrectListenOnSocket(settings map[string]any) bool { + val, ok := settings[listenOnSocketKey].(bool) + return ok && val +} + +func getMissingExtensions(settings map[string]any) []string { + requiredExtensions := []string{pythonExtension, jupyterExtension} + + extArray, ok := settings[defaultExtensionsKey].([]any) + if !ok { + return requiredExtensions + } + + existingExts := make(map[string]bool) + for _, ext := range extArray { + if extStr, ok := ext.(string); ok { + existingExts[extStr] = true + } + } + + var missing []string + for _, reqExt := range requiredExtensions { + if !existingExts[reqExt] { + missing = append(missing, reqExt) + } + } + return missing +} + +func validateSettings(settings map[string]any, connectionName string) *missingSettings { + return &missingSettings{ + portRange: !hasCorrectPortRange(settings, connectionName), + platform: !hasCorrectPlatform(settings, connectionName), + listenOnSocket: !hasCorrectListenOnSocket(settings), + extensions: getMissingExtensions(settings), + } +} + +func promptUserForUpdate(ctx context.Context, ide, connectionName string, _ *missingSettings) (bool, error) { + question := fmt.Sprintf("%s settings are missing required configuration for '%s'. Update settings?", getIDEName(ide), connectionName) + return cmdio.AskYesOrNo(ctx, question) +} + +func handleMissingFile(ctx context.Context, ide, connectionName, settingsPath string) error { + question := fmt.Sprintf("%s settings not found. Create settings with recommended configuration for '%s'?", getIDEName(ide), connectionName) + shouldCreate, err := cmdio.AskYesOrNo(ctx, question) + if err != nil { + return fmt.Errorf("failed to prompt user: %w", err) + } + if !shouldCreate { + log.Infof(ctx, "Skipping IDE settings creation") + return nil + } + + settingsDir := filepath.Dir(settingsPath) + if err := os.MkdirAll(settingsDir, 0o755); err != nil { + return fmt.Errorf("failed to create settings directory: %w", err) + } + + settings := make(map[string]any) + missing := &missingSettings{ + portRange: true, + platform: true, + listenOnSocket: true, + extensions: []string{pythonExtension, jupyterExtension}, + } + updateSettings(settings, connectionName, missing) + + if err := saveSettings(settingsPath, settings); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Created %s settings at %s", getIDEName(ide), filepath.ToSlash(settingsPath))) + return nil +} + +func backupSettings(ctx context.Context, path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + + backupPath := path + ".bak" + log.Infof(ctx, "Backing up settings to %s", filepath.ToSlash(backupPath)) + return os.WriteFile(backupPath, data, 0o600) +} + +func getOrDefault[T any](settings map[string]any, key string, defaultVal T) T { + if existing, ok := settings[key].(T); ok { + return existing + } + return defaultVal +} + +func updateSettings(settings map[string]any, connectionName string, missing *missingSettings) { + if missing.portRange { + portsConfig := getOrDefault(settings, serverPickPortsKey, make(map[string]any)) + portsConfig[connectionName] = portRange + settings[serverPickPortsKey] = portsConfig + } + + if missing.platform { + platformConfig := getOrDefault(settings, remotePlatformKey, make(map[string]any)) + platformConfig[connectionName] = remotePlatform + settings[remotePlatformKey] = platformConfig + } + + if missing.listenOnSocket { + settings[listenOnSocketKey] = true + } + + if len(missing.extensions) > 0 { + extArray := getOrDefault(settings, defaultExtensionsKey, []any{}) + existing := make(map[string]bool) + for _, ext := range extArray { + if extStr, ok := ext.(string); ok { + existing[extStr] = true + } + } + for _, ext := range missing.extensions { + if !existing[ext] { + extArray = append(extArray, ext) + } + } + settings[defaultExtensionsKey] = extArray + } +} + +func saveSettings(path string, settings map[string]any) error { + data, err := json.MarshalIndent(settings, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal settings: %w", err) + } + + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("failed to write settings file: %w", err) + } + + return nil +} + +func GetManualInstructions(ide, connectionName string) string { + return fmt.Sprintf( + "To ensure the remote connection works as expected, manually add these settings to your %s settings.json:\n"+ + " \"%s\": {\"%s\": \"%s\"},\n"+ + " \"%s\": {\"%s\": \"%s\"},\n"+ + " \"%s\": true,\n"+ + " \"%s\": [\"%s\", \"%s\"]", + getIDEName(ide), + serverPickPortsKey, connectionName, portRange, + remotePlatformKey, connectionName, remotePlatform, + listenOnSocketKey, + defaultExtensionsKey, pythonExtension, jupyterExtension) +} diff --git a/experimental/ssh/internal/vscode/settings_test.go b/experimental/ssh/internal/vscode/settings_test.go new file mode 100644 index 0000000000..2f6eb135e4 --- /dev/null +++ b/experimental/ssh/internal/vscode/settings_test.go @@ -0,0 +1,452 @@ +package vscode + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetDefaultSettingsPath_VSCode_Linux(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Skipping Linux-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/home/testuser") + + path, err := getDefaultSettingsPath(ctx, vscodeIDE) + require.NoError(t, err) + assert.Equal(t, "/home/testuser/.config/Code/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_Cursor_Linux(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Skipping Linux-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "HOME", "/home/testuser") + + path, err := getDefaultSettingsPath(ctx, cursorIDE) + require.NoError(t, err) + assert.Equal(t, "/home/testuser/.config/Cursor/User/settings.json", path) +} + +func TestGetDefaultSettingsPath_VSCode_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping Windows-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "APPDATA", `C:\Users\testuser\AppData\Roaming`) + + path, err := getDefaultSettingsPath(ctx, vscodeIDE) + require.NoError(t, err) + assert.Equal(t, `C:\Users\testuser\AppData\Roaming\Code\User\settings.json`, path) +} + +func TestGetDefaultSettingsPath_Cursor_Windows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Skipping Windows-specific test") + } + + ctx := context.Background() + ctx = env.Set(ctx, "APPDATA", `C:\Users\testuser\AppData\Roaming`) + + path, err := getDefaultSettingsPath(ctx, cursorIDE) + require.NoError(t, err) + assert.Equal(t, `C:\Users\testuser\AppData\Roaming\Cursor\User\settings.json`, path) +} + +func TestLoadSettings_Valid(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + settingsData := map[string]any{ + "editor.fontSize": 14, + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "test-conn": "4000-4005", + }, + } + data, err := json.Marshal(settingsData) + require.NoError(t, err) + err = os.WriteFile(settingsPath, data, 0o600) + require.NoError(t, err) + + settings, err := loadSettings(settingsPath) + require.NoError(t, err) + assert.InDelta(t, float64(14), settings["editor.fontSize"], 0.01) + assert.Contains(t, settings, "remote.SSH.serverPickPortsFromRange") +} + +func TestLoadSettings_Invalid(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + err := os.WriteFile(settingsPath, []byte("invalid json {"), 0o600) + require.NoError(t, err) + + _, err = loadSettings(settingsPath) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse settings JSON") +} + +func TestLoadSettings_WithComments(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + // JSONC format with comments and trailing commas (typical VS Code settings) + settingsData := `{ + // Editor settings + "editor.fontSize": 14, + /* Connection settings */ + "remote.SSH.serverPickPortsFromRange": { + "test-conn": "4000-4005" // Port range for SSH + }, + "remote.SSH.remotePlatform": { + "test-conn": "linux", // trailing comma + } + }` + err := os.WriteFile(settingsPath, []byte(settingsData), 0o600) + require.NoError(t, err) + + settings, err := loadSettings(settingsPath) + require.NoError(t, err) + assert.InDelta(t, float64(14), settings["editor.fontSize"], 0.01) + assert.Contains(t, settings, "remote.SSH.serverPickPortsFromRange") + + portRangeObj := settings["remote.SSH.serverPickPortsFromRange"].(map[string]any) + assert.Equal(t, "4000-4005", portRangeObj["test-conn"]) +} + +func TestLoadSettings_NotExists(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "nonexistent.json") + + _, err := loadSettings(settingsPath) + assert.Error(t, err) + assert.True(t, os.IsNotExist(err)) +} + +func TestValidateSettings_Complete(t *testing.T) { + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "test-conn": "4000-4005", + }, + "remote.SSH.remotePlatform": map[string]any{ + "test-conn": "linux", + }, + "remote.SSH.remoteServerListenOnSocket": true, + "remote.SSH.defaultExtensions": []any{ + "ms-python.python", + "ms-toolsai.jupyter", + }, + } + + missing := validateSettings(settings, "test-conn") + assert.True(t, missing.isEmpty()) +} + +func TestValidateSettings_Missing(t *testing.T) { + settings := map[string]any{} + + missing := validateSettings(settings, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Equal(t, []string{"ms-python.python", "ms-toolsai.jupyter"}, missing.extensions) +} + +func TestValidateSettings_IncorrectValues(t *testing.T) { + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "test-conn": "5000-5005", // Wrong port range + }, + "remote.SSH.remotePlatform": map[string]any{ + "test-conn": "windows", // Wrong platform + }, + "remote.SSH.defaultExtensions": []any{ + "ms-python.python", // Missing jupyter + }, + } + + missing := validateSettings(settings, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Equal(t, []string{"ms-toolsai.jupyter"}, missing.extensions) +} + +func TestValidateSettings_MissingConnection(t *testing.T) { + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "other-conn": "4000-4005", + }, + "remote.SSH.remotePlatform": map[string]any{ + "other-conn": "linux", + }, + "remote.SSH.defaultExtensions": []any{ + "ms-python.python", + "ms-toolsai.jupyter", + }, + } + + // Validating for a different connection should show port and platform as missing + missing := validateSettings(settings, "test-conn") + assert.False(t, missing.isEmpty()) + assert.True(t, missing.portRange) + assert.True(t, missing.platform) + assert.Empty(t, missing.extensions) // Extensions are global, so they're present +} + +func TestUpdateSettings_PreserveExistingConnections(t *testing.T) { + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "conn-a": "5000-5005", + "conn-b": "6000-6005", + }, + "remote.SSH.remotePlatform": map[string]any{ + "conn-a": "linux", + "conn-b": "darwin", + }, + "remote.SSH.defaultExtensions": []any{ + "other.extension", + }, + } + + missing := &missingSettings{ + portRange: true, + platform: true, + extensions: []string{"ms-python.python", "ms-toolsai.jupyter"}, + } + + updateSettings(settings, "conn-c", missing) + + // Check that new connection was added + portRangeObj := settings["remote.SSH.serverPickPortsFromRange"].(map[string]any) + assert.Equal(t, "4000-4005", portRangeObj["conn-c"]) + + platformObj := settings["remote.SSH.remotePlatform"].(map[string]any) + assert.Equal(t, "linux", platformObj["conn-c"]) + + // Check that existing connections were preserved + assert.Equal(t, "5000-5005", portRangeObj["conn-a"]) + assert.Equal(t, "6000-6005", portRangeObj["conn-b"]) + assert.Equal(t, "linux", platformObj["conn-a"]) + assert.Equal(t, "darwin", platformObj["conn-b"]) + + // Check that extensions were merged + extArray := settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 3) + assert.Contains(t, extArray, "other.extension") + assert.Contains(t, extArray, "ms-python.python") + assert.Contains(t, extArray, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_NewConnection(t *testing.T) { + settings := map[string]any{} + + missing := &missingSettings{ + portRange: true, + platform: true, + extensions: []string{"ms-python.python", "ms-toolsai.jupyter"}, + } + + updateSettings(settings, "new-conn", missing) + + portRangeObj := settings["remote.SSH.serverPickPortsFromRange"].(map[string]any) + assert.Equal(t, "4000-4005", portRangeObj["new-conn"]) + + platformObj := settings["remote.SSH.remotePlatform"].(map[string]any) + assert.Equal(t, "linux", platformObj["new-conn"]) + + extArray := settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 2) + assert.Contains(t, extArray, "ms-python.python") + assert.Contains(t, extArray, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_GlobalExtensions(t *testing.T) { + // Verify that extensions are global, not per-connection + settings := map[string]any{ + "remote.SSH.defaultExtensions": []any{ + "ms-python.python", + }, + } + + missing := &missingSettings{ + extensions: []string{"ms-toolsai.jupyter"}, + } + + updateSettings(settings, "conn-a", missing) + + extArray := settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 2) + assert.Contains(t, extArray, "ms-python.python") + assert.Contains(t, extArray, "ms-toolsai.jupyter") + + // Update for another connection should use the same global array + missing2 := &missingSettings{ + extensions: []string{"another.extension"}, + } + + updateSettings(settings, "conn-b", missing2) + + extArray = settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 3) + assert.Contains(t, extArray, "ms-python.python") + assert.Contains(t, extArray, "ms-toolsai.jupyter") + assert.Contains(t, extArray, "another.extension") +} + +func TestUpdateSettings_MergeExtensions(t *testing.T) { + settings := map[string]any{ + "remote.SSH.defaultExtensions": []any{ + "existing.extension", + "ms-python.python", + }, + } + + missing := &missingSettings{ + extensions: []string{"ms-python.python", "ms-toolsai.jupyter"}, + } + + updateSettings(settings, "test-conn", missing) + + extArray := settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 3) + assert.Contains(t, extArray, "existing.extension") + assert.Contains(t, extArray, "ms-python.python") + assert.Contains(t, extArray, "ms-toolsai.jupyter") +} + +func TestUpdateSettings_PartialUpdate(t *testing.T) { + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "test-conn": "4000-4005", // Already correct + }, + "remote.SSH.remotePlatform": map[string]any{ + "other-conn": "linux", + }, + "remote.SSH.defaultExtensions": []any{ + "ms-python.python", + "ms-toolsai.jupyter", + }, + } + + missing := &missingSettings{ + portRange: false, // Already set + platform: true, // Needs update + extensions: nil, // Already present + } + + updateSettings(settings, "test-conn", missing) + + // Port range should not be modified + portRangeObj := settings["remote.SSH.serverPickPortsFromRange"].(map[string]any) + assert.Equal(t, "4000-4005", portRangeObj["test-conn"]) + + // Platform should be added for test-conn + platformObj := settings["remote.SSH.remotePlatform"].(map[string]any) + assert.Equal(t, "linux", platformObj["test-conn"]) + assert.Equal(t, "linux", platformObj["other-conn"]) // Preserve other connection + + // Extensions should not be modified + extArray := settings["remote.SSH.defaultExtensions"].([]any) + assert.Len(t, extArray, 2) +} + +func TestBackupSettings(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + originalContent := []byte(`{"key": "value"}`) + err := os.WriteFile(settingsPath, originalContent, 0o600) + require.NoError(t, err) + + ctx, _ := cmdio.NewTestContextWithStderr(context.Background()) + err = backupSettings(ctx, settingsPath) + require.NoError(t, err) + + backupPath := settingsPath + ".bak" + backupContent, err := os.ReadFile(backupPath) + require.NoError(t, err) + assert.Equal(t, originalContent, backupContent) +} + +func TestSaveSettings_Formatting(t *testing.T) { + tmpDir := t.TempDir() + settingsPath := filepath.Join(tmpDir, "settings.json") + + settings := map[string]any{ + "remote.SSH.serverPickPortsFromRange": map[string]any{ + "test-conn": "4000-4005", + }, + "editor.fontSize": 14, + } + + err := saveSettings(settingsPath, settings) + require.NoError(t, err) + + content, err := os.ReadFile(settingsPath) + require.NoError(t, err) + + // Verify it's valid JSON + var parsed map[string]any + err = json.Unmarshal(content, &parsed) + require.NoError(t, err) + + // Verify formatting (should have 2-space indent) + assert.Contains(t, string(content), " \"remote.SSH.serverPickPortsFromRange\"") + + // Verify permissions + info, err := os.Stat(settingsPath) + require.NoError(t, err) + if runtime.GOOS != "windows" { + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) + } +} + +func TestMissingSettings_IsEmpty(t *testing.T) { + empty := &missingSettings{} + assert.True(t, empty.isEmpty()) + + notEmpty := &missingSettings{portRange: true} + assert.False(t, notEmpty.isEmpty()) + + notEmpty2 := &missingSettings{extensions: []string{"ext"}} + assert.False(t, notEmpty2.isEmpty()) +} + +func TestGetManualInstructions_VSCode(t *testing.T) { + instructions := GetManualInstructions(vscodeIDE, "test-conn") + + assert.Contains(t, instructions, "VS Code") + assert.Contains(t, instructions, "test-conn") + assert.Contains(t, instructions, "4000-4005") + assert.Contains(t, instructions, "linux") + assert.Contains(t, instructions, "ms-python.python") + assert.Contains(t, instructions, "ms-toolsai.jupyter") + assert.Contains(t, instructions, "remote.SSH.serverPickPortsFromRange") + assert.Contains(t, instructions, "remote.SSH.remotePlatform") + assert.Contains(t, instructions, "remote.SSH.defaultExtensions") +} + +func TestGetManualInstructions_Cursor(t *testing.T) { + instructions := GetManualInstructions("cursor", "my-connection") + + assert.Contains(t, instructions, "Cursor") + assert.Contains(t, instructions, "my-connection") + assert.Contains(t, instructions, "4000-4005") + assert.Contains(t, instructions, "linux") + assert.Contains(t, instructions, "ms-python.python") + assert.Contains(t, instructions, "ms-toolsai.jupyter") +} diff --git a/go.mod b/go.mod index 1afeb28c07..8ec577eeae 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,9 @@ require github.com/google/jsonschema-go v0.4.2 // MIT require gopkg.in/yaml.v3 v3.0.1 // indirect +// Dependencies for the experimental SSH commands +require github.com/tidwall/jsonc v0.3.2 // MIT + require ( cloud.google.com/go/auth v0.18.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect diff --git a/go.sum b/go.sum index 1894726358..5ee797e534 100644 --- a/go.sum +++ b/go.sum @@ -216,6 +216,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/jsonc v0.3.2 h1:ZTKrmejRlAJYdn0kcaFqRAKlxxFIC21pYq8vLa4p2Wc= +github.com/tidwall/jsonc v0.3.2/go.mod h1:dw+3CIxqHi+t8eFSpzzMlcVYxKp08UP5CD8/uSFCyJE= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=