Skip to content

Commit 4ce74a1

Browse files
committed
chore: introduce flags helper for consistency
1 parent cf8b50b commit 4ce74a1

File tree

11 files changed

+217
-28
lines changed

11 files changed

+217
-28
lines changed

cmd/model/get.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/openfga/cli/internal/authorizationmodel"
2626
"github.com/openfga/cli/internal/cmdutils"
27+
"github.com/openfga/cli/internal/flags"
2728
"github.com/openfga/cli/internal/output"
2829
)
2930

@@ -77,8 +78,8 @@ func init() {
7778
getCmd.Flags().StringArray("field", []string{"model"}, "Fields to display, choices are: id, created_at and model") //nolint:lll
7879
getCmd.Flags().Var(&getOutputFormat, "format", `Authorization model output format. Can be "fga" or "json"`)
7980

80-
if err := getCmd.MarkFlagRequired("store-id"); err != nil {
81-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/get", err)
81+
if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/models/get", false); err != nil {
82+
fmt.Print(err)
8283
os.Exit(1)
8384
}
8485
}

cmd/model/list.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727

2828
"github.com/openfga/cli/internal/authorizationmodel"
2929
"github.com/openfga/cli/internal/cmdutils"
30+
"github.com/openfga/cli/internal/flags"
3031
"github.com/openfga/cli/internal/output"
3132
)
3233

@@ -112,8 +113,8 @@ func init() {
112113
listCmd.Flags().String("store-id", "", "Store ID")
113114
listCmd.Flags().StringArray("field", []string{"id", "created_at"}, "Fields to display, choices are: id, created_at and model") //nolint:lll
114115

115-
if err := listCmd.MarkFlagRequired("store-id"); err != nil {
116-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/list", err)
116+
if err := flags.SetFlagRequired(listCmd, "store-id", "cmd/models/list", false); err != nil {
117+
fmt.Print(err)
117118
os.Exit(1)
118119
}
119120
}

cmd/model/test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/spf13/cobra"
2525

2626
"github.com/openfga/cli/internal/cmdutils"
27+
"github.com/openfga/cli/internal/flags"
2728
"github.com/openfga/cli/internal/output"
2829
"github.com/openfga/cli/internal/storetest"
2930
)
@@ -99,8 +100,8 @@ func init() {
99100
testCmd.Flags().Bool("verbose", false, "Print verbose JSON output")
100101
testCmd.Flags().Bool("suppress-summary", false, "Suppress the plain text summary output")
101102

102-
if err := testCmd.MarkFlagRequired("tests"); err != nil {
103-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/test", err)
103+
if err := flags.SetFlagRequired(testCmd, "tests", "cmd/models/test", false); err != nil {
104+
fmt.Print(err)
104105
os.Exit(1)
105106
}
106107
}

cmd/model/write.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727

2828
"github.com/openfga/cli/internal/authorizationmodel"
2929
"github.com/openfga/cli/internal/cmdutils"
30+
"github.com/openfga/cli/internal/flags"
3031
"github.com/openfga/cli/internal/output"
3132
"github.com/openfga/cli/internal/utils"
3233
)
@@ -106,8 +107,8 @@ func init() {
106107
writeCmd.Flags().String("file", "", "File Name. The file should have the model in the JSON or DSL format")
107108
writeCmd.Flags().Var(&writeInputFormat, "format", `Authorization model input format. Can be "fga", "json", or "modular"`) //nolint:lll
108109

109-
if err := writeCmd.MarkFlagRequired("store-id"); err != nil {
110-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err)
110+
if err := flags.SetFlagRequired(writeCmd, "tests", "cmd/models/write", false); err != nil {
111+
fmt.Print(err)
111112
os.Exit(1)
112113
}
113114
}

cmd/query/list-users.go

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/spf13/cobra"
2828

2929
"github.com/openfga/cli/internal/cmdutils"
30+
"github.com/openfga/cli/internal/flags"
3031
"github.com/openfga/cli/internal/output"
3132
)
3233

@@ -138,18 +139,8 @@ func init() {
138139
listUsersCmd.Flags().String("relation", "", "Relation to evaluate on")
139140
listUsersCmd.Flags().String("user-filter", "", "Filter the responses can be in the formats <type> (to filter objects and typed public bound access) or <type>#<relation> (to filter usersets)") //nolint:lll
140141

141-
if err := listUsersCmd.MarkFlagRequired("object"); err != nil {
142-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
143-
os.Exit(1)
144-
}
145-
146-
if err := listUsersCmd.MarkFlagRequired("relation"); err != nil {
147-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
148-
os.Exit(1)
149-
}
150-
151-
if err := listUsersCmd.MarkFlagRequired("user-filter"); err != nil {
152-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/query/list-users", err)
142+
if err := flags.SetFlagsRequired(listUsersCmd, []string{"object", "relation", "user-filter"}, "cmd/query/list-users", false); err != nil {
143+
fmt.Print(err)
153144
os.Exit(1)
154145
}
155146
}

cmd/store/delete.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525

2626
"github.com/openfga/cli/internal/cmdutils"
2727
"github.com/openfga/cli/internal/confirmation"
28+
"github.com/openfga/cli/internal/flags"
2829
"github.com/openfga/cli/internal/output"
2930
)
3031

@@ -72,8 +73,7 @@ func init() {
7273
deleteCmd.Flags().String("store-id", "", "Store ID")
7374
deleteCmd.Flags().Bool("force", false, "Force delete without confirmation")
7475

75-
err := deleteCmd.MarkFlagRequired("store-id")
76-
if err != nil {
76+
if err := flags.SetFlagRequired(deleteCmd, "store-id", "cmd/store/delete", false); err != nil {
7777
fmt.Print(err)
7878
os.Exit(1)
7979
}

cmd/store/export.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/openfga/cli/internal/cmdutils"
3131
"github.com/openfga/cli/internal/confirmation"
3232
"github.com/openfga/cli/internal/fga"
33+
"github.com/openfga/cli/internal/flags"
3334
"github.com/openfga/cli/internal/output"
3435
"github.com/openfga/cli/internal/storetest"
3536
"github.com/openfga/cli/internal/tuple"
@@ -193,8 +194,7 @@ func init() {
193194
exportCmd.Flags().String("model-id", "", "Authorization Model ID")
194195
exportCmd.Flags().Uint("max-tuples", defaultMaxTupleCount, "max number of tuples to return in the output")
195196

196-
err := exportCmd.MarkFlagRequired("store-id")
197-
if err != nil {
197+
if err := flags.SetFlagRequired(exportCmd, "store-id", "cmd/store/export", false); err != nil {
198198
fmt.Print(err)
199199
os.Exit(1)
200200
}

cmd/store/get.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626

2727
"github.com/openfga/cli/internal/cmdutils"
2828
"github.com/openfga/cli/internal/fga"
29+
"github.com/openfga/cli/internal/flags"
2930
"github.com/openfga/cli/internal/output"
3031
)
3132

@@ -65,8 +66,7 @@ var getCmd = &cobra.Command{
6566
func init() {
6667
getCmd.Flags().String("store-id", "", "Store ID")
6768

68-
err := getCmd.MarkFlagRequired("store-id")
69-
if err != nil {
69+
if err := flags.SetFlagRequired(getCmd, "store-id", "cmd/store/get", false); err != nil {
7070
fmt.Print(err)
7171
os.Exit(1)
7272
}

cmd/store/import.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/openfga/cli/internal/authorizationmodel"
3535
"github.com/openfga/cli/internal/cmdutils"
3636
"github.com/openfga/cli/internal/fga"
37+
"github.com/openfga/cli/internal/flags"
3738
"github.com/openfga/cli/internal/output"
3839
"github.com/openfga/cli/internal/storetest"
3940
"github.com/openfga/cli/internal/tuple"
@@ -339,8 +340,8 @@ func init() {
339340
importCmd.Flags().Int("max-tuples-per-write", tuple.MaxTuplesPerWrite, "Max tuples per write chunk.")
340341
importCmd.Flags().Int("max-parallel-requests", tuple.MaxParallelRequests, "Max number of requests to issue to the server in parallel.") //nolint:lll
341342

342-
if err := importCmd.MarkFlagRequired("file"); err != nil {
343-
fmt.Printf("error setting flag as required - %v: %v\n", "cmd/models/write", err)
343+
if err := flags.SetFlagRequired(importCmd, "file", "cmd/store/import", false); err != nil {
344+
fmt.Print(err)
344345
os.Exit(1)
345346
}
346347
}

internal/flags/flags.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Package flags provides utility functions for working with cobra command flags.
2+
// It simplifies the process of marking flags as required and handling related errors.
3+
package flags
4+
5+
import (
6+
"errors"
7+
"fmt"
8+
9+
"github.com/spf13/cobra"
10+
)
11+
12+
const flagRequiredErrorMsg = "error setting %s flag as required - %s: %v"
13+
14+
// SetFlagRequired marks a single flag as required for a cobra command.
15+
// If isPersistent is true, it marks the persistent flag as required.
16+
// Returns an error if the flag cannot be marked as required.
17+
func SetFlagRequired(cmd *cobra.Command, flag string, location string, isPersistent bool) error {
18+
if isPersistent {
19+
if err := cmd.MarkPersistentFlagRequired(flag); err != nil {
20+
return fmt.Errorf(flagRequiredErrorMsg, flag, location, err)
21+
}
22+
} else {
23+
if err := cmd.MarkFlagRequired(flag); err != nil {
24+
return fmt.Errorf(flagRequiredErrorMsg, flag, location, err)
25+
}
26+
}
27+
28+
return nil
29+
}
30+
31+
// SetFlagsRequired marks multiple flags as required for a cobra command.
32+
// If isPersistent is true, it marks the persistent flags as required.
33+
// Returns a joined error if any flag cannot be marked as required.
34+
func SetFlagsRequired(cmd *cobra.Command, flags []string, location string, isPersistent bool) error {
35+
flagErrors := make([]error, len(flags))
36+
37+
for i, flag := range flags {
38+
if err := SetFlagRequired(cmd, flag, location, isPersistent); err != nil {
39+
flagErrors[i] = err
40+
}
41+
}
42+
43+
return errors.Join(flagErrors...)
44+
}

0 commit comments

Comments
 (0)