fix: correctly parse negative boolean flags

This commit is contained in:
Henrique Dias 2025-06-25 17:24:06 +02:00
parent f46641b038
commit 221451a517
No known key found for this signature in database
3 changed files with 60 additions and 30 deletions

View File

@ -201,42 +201,42 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server {
server, err := st.Settings.GetServer() server, err := st.Settings.GetServer()
checkErr(err) checkErr(err)
if val, set := getParamB(flags, "root"); set { if val, set := getStringParamB(flags, "root"); set {
server.Root = val server.Root = val
} }
if val, set := getParamB(flags, "baseurl"); set { if val, set := getStringParamB(flags, "baseurl"); set {
server.BaseURL = val server.BaseURL = val
} }
if val, set := getParamB(flags, "log"); set { if val, set := getStringParamB(flags, "log"); set {
server.Log = val server.Log = val
} }
isSocketSet := false isSocketSet := false
isAddrSet := false isAddrSet := false
if val, set := getParamB(flags, "address"); set { if val, set := getStringParamB(flags, "address"); set {
server.Address = val server.Address = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getParamB(flags, "port"); set { if val, set := getStringParamB(flags, "port"); set {
server.Port = val server.Port = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getParamB(flags, "key"); set { if val, set := getStringParamB(flags, "key"); set {
server.TLSKey = val server.TLSKey = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getParamB(flags, "cert"); set { if val, set := getStringParamB(flags, "cert"); set {
server.TLSCert = val server.TLSCert = val
isAddrSet = isAddrSet || set isAddrSet = isAddrSet || set
} }
if val, set := getParamB(flags, "socket"); set { if val, set := getStringParamB(flags, "socket"); set {
server.Socket = val server.Socket = val
isSocketSet = isSocketSet || set isSocketSet = isSocketSet || set
} }
@ -250,33 +250,62 @@ func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server {
server.Socket = "" server.Socket = ""
} }
_, disableThumbnails := getParamB(flags, "disable-thumbnails") disableThumbnails := getBoolParam(flags, "disable-thumbnails")
server.EnableThumbnails = !disableThumbnails server.EnableThumbnails = !disableThumbnails
_, disablePreviewResize := getParamB(flags, "disable-preview-resize") disablePreviewResize := getBoolParam(flags, "disable-preview-resize")
server.ResizePreview = !disablePreviewResize server.ResizePreview = !disablePreviewResize
_, disableTypeDetectionByHeader := getParamB(flags, "disable-type-detection-by-header") disableTypeDetectionByHeader := getBoolParam(flags, "disable-type-detection-by-header")
server.TypeDetectionByHeader = !disableTypeDetectionByHeader server.TypeDetectionByHeader = !disableTypeDetectionByHeader
_, disableExec := getParamB(flags, "disable-exec") disableExec := getBoolParam(flags, "disable-exec")
server.EnableExec = !disableExec server.EnableExec = !disableExec
if val, set := getParamB(flags, "token-expiration-time"); set { if val, set := getStringParamB(flags, "token-expiration-time"); set {
server.TokenExpirationTime = val server.TokenExpirationTime = val
} }
return server return server
} }
// getParamB returns a parameter as a string and a boolean to tell if it is different from the default // getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default
// //
// NOTE: we could simply bind the flags to viper and use IsSet. // NOTE: we could simply bind the flags to viper and use IsSet.
// Although there is a bug on Viper that always returns true on IsSet // Although there is a bug on Viper that always returns true on IsSet
// if a flag is binded. Our alternative way is to manually check // if a flag is binded. Our alternative way is to manually check
// the flag and then the value from env/config/gotten by viper. // the flag and then the value from env/config/gotten by viper.
// https://github.com/spf13/viper/pull/331 // https://github.com/spf13/viper/pull/331
func getParamB(flags *pflag.FlagSet, key string) (string, bool) { func getBoolParamB(flags *pflag.FlagSet, key string) (bool, bool) {
value, _ := flags.GetBool(key)
// If set on Flags, use it.
if flags.Changed(key) {
return value, true
}
// If set through viper (env, config), return it.
if v.IsSet(key) {
return v.GetBool(key), true
}
// Otherwise use default value on flags.
return value, false
}
func getBoolParam(flags *pflag.FlagSet, key string) bool {
val, _ := getBoolParamB(flags, key)
return val
}
// getStringParamB returns a parameter as a string and a boolean to tell if it is different from the default
//
// NOTE: we could simply bind the flags to viper and use IsSet.
// Although there is a bug on Viper that always returns true on IsSet
// if a flag is binded. Our alternative way is to manually check
// the flag and then the value from env/config/gotten by viper.
// https://github.com/spf13/viper/pull/331
func getStringParamB(flags *pflag.FlagSet, key string) (string, bool) {
value, _ := flags.GetString(key) value, _ := flags.GetString(key)
// If set on Flags, use it. // If set on Flags, use it.
@ -293,8 +322,8 @@ func getParamB(flags *pflag.FlagSet, key string) (string, bool) {
return value, false return value, false
} }
func getParam(flags *pflag.FlagSet, key string) string { func getStringParam(flags *pflag.FlagSet, key string) string {
val, _ := getParamB(flags, key) val, _ := getStringParamB(flags, key)
return val return val
} }
@ -349,7 +378,7 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) {
} }
var err error var err error
if _, noauth := getParamB(flags, "noauth"); noauth { if _, noauth := getStringParamB(flags, "noauth"); noauth {
set.AuthMethod = auth.MethodNoAuth set.AuthMethod = auth.MethodNoAuth
err = d.store.Auth.Save(&auth.NoAuth{}) err = d.store.Auth.Save(&auth.NoAuth{})
} else { } else {
@ -362,27 +391,27 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) {
checkErr(err) checkErr(err)
ser := &settings.Server{ ser := &settings.Server{
BaseURL: getParam(flags, "baseurl"), BaseURL: getStringParam(flags, "baseurl"),
Port: getParam(flags, "port"), Port: getStringParam(flags, "port"),
Log: getParam(flags, "log"), Log: getStringParam(flags, "log"),
TLSKey: getParam(flags, "key"), TLSKey: getStringParam(flags, "key"),
TLSCert: getParam(flags, "cert"), TLSCert: getStringParam(flags, "cert"),
Address: getParam(flags, "address"), Address: getStringParam(flags, "address"),
Root: getParam(flags, "root"), Root: getStringParam(flags, "root"),
} }
err = d.store.Settings.SaveServer(ser) err = d.store.Settings.SaveServer(ser)
checkErr(err) checkErr(err)
username := getParam(flags, "username") username := getStringParam(flags, "username")
password := getParam(flags, "password") password := getStringParam(flags, "password")
if password == "" { if password == "" {
var pwd string var pwd string
pwd, err = users.RandomPwd() pwd, err = users.RandomPwd()
checkErr(err) checkErr(err)
log.Println("Generated random admin password for quick setup:", pwd) log.Println("Randomly generated password for user 'admin':", pwd)
password, err = users.HashPwd(pwd) password, err = users.HashPwd(pwd)
checkErr(err) checkErr(err)
@ -420,6 +449,7 @@ func initConfig() {
v.SetEnvPrefix("FB") v.SetEnvPrefix("FB")
v.AutomaticEnv() v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
if err := v.ReadInConfig(); err != nil { if err := v.ReadInConfig(); err != nil {
var configParseError v.ConfigParseError var configParseError v.ConfigParseError

View File

@ -25,7 +25,7 @@ this version.`,
flags := cmd.Flags() flags := cmd.Flags()
oldDB := mustGetString(flags, "old.database") oldDB := mustGetString(flags, "old.database")
oldConf := mustGetString(flags, "old.config") oldConf := mustGetString(flags, "old.config")
err := importer.Import(oldDB, oldConf, getParam(flags, "database")) err := importer.Import(oldDB, oldConf, getStringParam(flags, "database"))
checkErr(err) checkErr(err)
}, },
} }

View File

@ -86,7 +86,7 @@ func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
return func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) {
data := pythonData{hadDB: true} data := pythonData{hadDB: true}
path := getParam(cmd.Flags(), "database") path := getStringParam(cmd.Flags(), "database")
absPath, err := filepath.Abs(path) absPath, err := filepath.Abs(path)
if err != nil { if err != nil {
panic(err) panic(err)