From 221451a5179c8f139819a315b80d0ecb0e7220c3 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Wed, 25 Jun 2025 17:24:06 +0200 Subject: [PATCH] fix: correctly parse negative boolean flags --- cmd/root.go | 86 ++++++++++++++++++++++++++++++++++---------------- cmd/upgrade.go | 2 +- cmd/utils.go | 2 +- 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index fbd27541..3ad52f3f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -201,42 +201,42 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server { server, err := st.Settings.GetServer() checkErr(err) - if val, set := getParamB(flags, "root"); set { + if val, set := getStringParamB(flags, "root"); set { server.Root = val } - if val, set := getParamB(flags, "baseurl"); set { + if val, set := getStringParamB(flags, "baseurl"); set { server.BaseURL = val } - if val, set := getParamB(flags, "log"); set { + if val, set := getStringParamB(flags, "log"); set { server.Log = val } isSocketSet := false isAddrSet := false - if val, set := getParamB(flags, "address"); set { + if val, set := getStringParamB(flags, "address"); set { server.Address = val isAddrSet = isAddrSet || set } - if val, set := getParamB(flags, "port"); set { + if val, set := getStringParamB(flags, "port"); set { server.Port = val isAddrSet = isAddrSet || set } - if val, set := getParamB(flags, "key"); set { + if val, set := getStringParamB(flags, "key"); set { server.TLSKey = val isAddrSet = isAddrSet || set } - if val, set := getParamB(flags, "cert"); set { + if val, set := getStringParamB(flags, "cert"); set { server.TLSCert = val isAddrSet = isAddrSet || set } - if val, set := getParamB(flags, "socket"); set { + if val, set := getStringParamB(flags, "socket"); set { server.Socket = val isSocketSet = isSocketSet || set } @@ -250,33 +250,62 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server { server.Socket = "" } - _, disableThumbnails := getParamB(flags, "disable-thumbnails") + disableThumbnails := getBoolParam(flags, "disable-thumbnails") server.EnableThumbnails = !disableThumbnails - _, disablePreviewResize := getParamB(flags, "disable-preview-resize") + disablePreviewResize := getBoolParam(flags, "disable-preview-resize") server.ResizePreview = !disablePreviewResize - _, disableTypeDetectionByHeader := getParamB(flags, "disable-type-detection-by-header") + disableTypeDetectionByHeader := getBoolParam(flags, "disable-type-detection-by-header") server.TypeDetectionByHeader = !disableTypeDetectionByHeader - _, disableExec := getParamB(flags, "disable-exec") + disableExec := getBoolParam(flags, "disable-exec") server.EnableExec = !disableExec - if val, set := getParamB(flags, "token-expiration-time"); set { + if val, set := getStringParamB(flags, "token-expiration-time"); set { server.TokenExpirationTime = val } return server } -// getParamB returns a parameter as a string and a boolean to tell if it is different from the default +// getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default // // NOTE: we could simply bind the flags to viper and use IsSet. // Although there is a bug on Viper that always returns true on IsSet // if a flag is binded. Our alternative way is to manually check // the flag and then the value from env/config/gotten by viper. // https://github.com/spf13/viper/pull/331 -func getParamB(flags *pflag.FlagSet, key string) (string, bool) { +func getBoolParamB(flags *pflag.FlagSet, key string) (bool, bool) { + value, _ := flags.GetBool(key) + + // If set on Flags, use it. + if flags.Changed(key) { + return value, true + } + + // If set through viper (env, config), return it. + if v.IsSet(key) { + return v.GetBool(key), true + } + + // Otherwise use default value on flags. + return value, false +} + +func getBoolParam(flags *pflag.FlagSet, key string) bool { + val, _ := getBoolParamB(flags, key) + return val +} + +// getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default +// +// NOTE: we could simply bind the flags to viper and use IsSet. +// Although there is a bug on Viper that always returns true on IsSet +// if a flag is binded. Our alternative way is to manually check +// the flag and then the value from env/config/gotten by viper. +// https://github.com/spf13/viper/pull/331 +func getStringParamB(flags *pflag.FlagSet, key string) (string, bool) { value, _ := flags.GetString(key) // If set on Flags, use it. @@ -293,8 +322,8 @@ func getParamB(flags *pflag.FlagSet, key string) (string, bool) { return value, false } -func getParam(flags *pflag.FlagSet, key string) string { - val, _ := getParamB(flags, key) +func getStringParam(flags *pflag.FlagSet, key string) string { + val, _ := getStringParamB(flags, key) return val } @@ -349,7 +378,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) { } var err error - if _, noauth := getParamB(flags, "noauth"); noauth { + if _, noauth := getStringParamB(flags, "noauth"); noauth { set.AuthMethod = auth.MethodNoAuth err = d.store.Auth.Save(&auth.NoAuth{}) } else { @@ -362,27 +391,27 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) { checkErr(err) ser := &settings.Server{ - BaseURL: getParam(flags, "baseurl"), - Port: getParam(flags, "port"), - Log: getParam(flags, "log"), - TLSKey: getParam(flags, "key"), - TLSCert: getParam(flags, "cert"), - Address: getParam(flags, "address"), - Root: getParam(flags, "root"), + BaseURL: getStringParam(flags, "baseurl"), + Port: getStringParam(flags, "port"), + Log: getStringParam(flags, "log"), + TLSKey: getStringParam(flags, "key"), + TLSCert: getStringParam(flags, "cert"), + Address: getStringParam(flags, "address"), + Root: getStringParam(flags, "root"), } err = d.store.Settings.SaveServer(ser) checkErr(err) - username := getParam(flags, "username") - password := getParam(flags, "password") + username := getStringParam(flags, "username") + password := getStringParam(flags, "password") if password == "" { var pwd string pwd, err = users.RandomPwd() checkErr(err) - log.Println("Generated random admin password for quick setup:", pwd) + log.Println("Randomly generated password for user 'admin':", pwd) password, err = users.HashPwd(pwd) checkErr(err) @@ -420,6 +449,7 @@ func initConfig() { v.SetEnvPrefix("FB") v.AutomaticEnv() v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) if err := v.ReadInConfig(); err != nil { var configParseError v.ConfigParseError diff --git a/cmd/upgrade.go b/cmd/upgrade.go index 83a0729c..f6966e2e 100644 --- a/cmd/upgrade.go +++ b/cmd/upgrade.go @@ -25,7 +25,7 @@ this version.`, flags := cmd.Flags() oldDB := mustGetString(flags, "old.database") oldConf := mustGetString(flags, "old.config") - err := importer.Import(oldDB, oldConf, getParam(flags, "database")) + err := importer.Import(oldDB, oldConf, getStringParam(flags, "database")) checkErr(err) }, } diff --git a/cmd/utils.go b/cmd/utils.go index 78f48d13..49cc2d2f 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -86,7 +86,7 @@ func python(fn pythonFunc, cfg pythonConfig) cobraFunc { return func(cmd *cobra.Command, args []string) { data := pythonData{hadDB: true} - path := getParam(cmd.Flags(), "database") + path := getStringParam(cmd.Flags(), "database") absPath, err := filepath.Abs(path) if err != nil { panic(err)