From a7b88cf1764d6056bebee1519660e104340c4ee1 Mon Sep 17 00:00:00 2001 From: David Levy Date: Thu, 5 Feb 2026 20:08:08 -0600 Subject: [PATCH] fix: validate config file extension before use Provide clear error message when user specifies a config file with unsupported extension (e.g., .txt, .json). Only .yaml, .yml, and no extension are allowed. --- internal/config/config.go | 14 ++++++++--- internal/config/viper.go | 24 +++++++++++++++++- internal/config/viper_test.go | 47 +++++++++++++++++++++++++++++++++-- 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 1b24695c..0bc9490b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,13 +4,14 @@ package config import ( + "os" + "path/filepath" + "testing" + . "github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig" "github.com/microsoft/go-sqlcmd/internal/io/file" "github.com/microsoft/go-sqlcmd/internal/io/folder" "github.com/microsoft/go-sqlcmd/internal/pal" - "os" - "path/filepath" - "testing" ) var config Sqlconfig @@ -26,8 +27,13 @@ func SetFileName(name string) { filename = name + // Validate extension before creating the file + err := validateConfigFileExtension(filename) + checkErr(err) + file.CreateEmptyIfNotExists(filename) - configureViper(filename) + err = configureViper(filename) + checkErr(err) } func SetFileNameForTest(t *testing.T) { diff --git a/internal/config/viper.go b/internal/config/viper.go index 9409fd8d..33a9b373 100644 --- a/internal/config/viper.go +++ b/internal/config/viper.go @@ -5,6 +5,10 @@ package config import ( "bytes" + "path/filepath" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/spf13/viper" "gopkg.in/yaml.v2" @@ -56,16 +60,34 @@ func GetConfigFileUsed() string { return viper.ConfigFileUsed() } +// validateConfigFileExtension checks if the config file has a supported extension. +// Allows .yaml, .yml, and no extension (for default sqlconfig file). +func validateConfigFileExtension(configFile string) error { + ext := strings.ToLower(filepath.Ext(configFile)) + if ext == "" || ext == ".yaml" || ext == ".yml" { + return nil + } + return localizer.Errorf( + "Configuration files must use YAML format (.yaml or .yml extension). "+ + "File '%s' has unsupported extension '%s'.", + configFile, ext) +} + // configureViper initializes the Viper library with the given configuration file. // This function sets the configuration file type to "yaml" and sets the environment variable prefix to "SQLCMD". // It also sets the configuration file to use to the one provided as an argument to the function. // This function is intended to be called at the start of the application to configure Viper before any other code uses it. -func configureViper(configFile string) { +func configureViper(configFile string) error { if configFile == "" { panic("Must provide configFile") } + if err := validateConfigFileExtension(configFile); err != nil { + return err + } + viper.SetConfigType("yaml") viper.SetEnvPrefix("SQLCMD") viper.SetConfigFile(configFile) + return nil } diff --git a/internal/config/viper_test.go b/internal/config/viper_test.go index 3a606f6d..a4ddc755 100644 --- a/internal/config/viper_test.go +++ b/internal/config/viper_test.go @@ -4,16 +4,59 @@ package config import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func Test_configureViper(t *testing.T) { assert.Panics(t, func() { - configureViper("") + _ = configureViper("") }) } +func Test_configureViperValidExtensions(t *testing.T) { + tests := []string{"config.yaml", "config.yml", "sqlconfig", "/path/to/config.YAML"} + for _, name := range tests { + t.Run(name, func(t *testing.T) { + err := configureViper(name) + assert.NoError(t, err) + }) + } +} + +func Test_configureViperInvalidExtension(t *testing.T) { + err := configureViper("config.txt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "YAML format") + assert.Contains(t, err.Error(), ".txt") +} + +func Test_validateConfigFileExtension(t *testing.T) { + tests := []struct { + name string + file string + wantErr bool + }{ + {"yaml extension", "config.yaml", false}, + {"yml extension", "config.yml", false}, + {"no extension", "sqlconfig", false}, + {"uppercase YAML", "config.YAML", false}, + {"txt extension", "config.txt", true}, + {"json extension", "config.json", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateConfigFileExtension(tt.file) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + func Test_Load(t *testing.T) { SetFileNameForTest(t) Clean()