Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,12 @@ func cleanup(server *http.Server) {
}()
}

func openBrowser(pageUrl string) error {
var err error
func openBrowser(pageUrl string) (err error) {
err = utils.ValidateURLDomain(pageUrl)
if err != nil {
return err
}

switch runtime.GOOS {
case "freebsd", "linux":
// We need to use the windows way on WSL, otherwise we do not pass query
Expand Down
9 changes: 9 additions & 0 deletions internal/pkg/auth/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,21 @@ func parseWellKnownConfiguration(httpClient apiClient, wellKnownConfigURL string
if wellKnownConfig.Issuer == "" {
return nil, fmt.Errorf("found no issuer")
}
if utils.ValidateURLDomain(wellKnownConfig.Issuer) != nil {
return nil, fmt.Errorf("issuer is invalid")
}
if wellKnownConfig.AuthorizationEndpoint == "" {
return nil, fmt.Errorf("found no authorization endpoint")
}
if utils.ValidateURLDomain(wellKnownConfig.AuthorizationEndpoint) != nil {
return nil, fmt.Errorf("authorization endpoint is invalid")
}
if wellKnownConfig.TokenEndpoint == "" {
return nil, fmt.Errorf("found no token endpoint")
}
if utils.ValidateURLDomain(wellKnownConfig.TokenEndpoint) != nil {
return nil, fmt.Errorf("token endpoint is invalid")
}

err = SetAuthField(IDP_TOKEN_ENDPOINT, wellKnownConfig.TokenEndpoint)
if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions internal/pkg/auth/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ func TestParseWellKnownConfig(t *testing.T) {
{
name: "success",
getFails: false,
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth","token_endpoint":"token"}`,
getResponse: `{"issuer":"https://issuer.stackit.cloud/endpoint","authorization_endpoint":"https://auth.stackit.cloud/enpoint","token_endpoint":"https://token.stackit.cloud/endpoint"}`,
isValid: true,
expected: &wellKnownConfig{
Issuer: "issuer",
AuthorizationEndpoint: "auth",
TokenEndpoint: "token",
Issuer: "https://issuer.stackit.cloud/endpoint",
AuthorizationEndpoint: "https://auth.stackit.cloud/enpoint",
TokenEndpoint: "https://token.stackit.cloud/endpoint",
},
},
{
Expand All @@ -158,21 +158,21 @@ func TestParseWellKnownConfig(t *testing.T) {
{
name: "missing_issuer",
getFails: true,
getResponse: `{"authorization_endpoint":"auth","token_endpoint":"token"}`,
getResponse: `{"authorization_endpoint":"https://auth.stackit.cloud/enpoint","token_endpoint":"https://token.stackit.cloud/endpoint"}`,
isValid: false,
expected: nil,
},
{
name: "missing_authorization",
getFails: true,
getResponse: `{"issuer":"issuer","token_endpoint":"token"}`,
getResponse: `{"issuer":"https://issuer.stackit.cloud/endpoint","token_endpoint":"https://token.stackit.cloud/endpoint"}`,
isValid: false,
expected: nil,
},
{
name: "missing_token",
getFails: true,
getResponse: `{"issuer":"issuer","authorization_endpoint":"auth"}`,
getResponse: `{"issuer":"https://issuer.stackit.cloud/endpoint","authorization_endpoint":"https://auth.stackit.cloud/enpoint"}`,
isValid: false,
expected: nil,
},
Expand Down
9 changes: 9 additions & 0 deletions internal/pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/base64"
"fmt"
"net/url"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -82,11 +83,19 @@ func ValidateURLDomain(value string) error {
if err != nil {
return fmt.Errorf("parse url: %w", err)
}

urlHost := urlStruct.Hostname()
if urlHost == "" {
return fmt.Errorf("bad url")
}

allowedSchemes := []string{
"https",
}
if !slices.Contains(allowedSchemes, urlStruct.Scheme) {
return fmt.Errorf("unsupported protocol: %s", urlStruct.Scheme)
}

allowedUrlDomain := viper.GetString(config.AllowedUrlDomainKey)

if !strings.HasSuffix(urlHost, allowedUrlDomain) {
Expand Down
15 changes: 15 additions & 0 deletions internal/pkg/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ func TestValidateURLDomain(t *testing.T) {
input: "",
isValid: false,
},
{
name: "invalid protocol",
input: "http://example.stackit.cloud",
isValid: false,
},
{
name: "no protocol",
input: "example.stackit.cloud",
isValid: false,
},
{
name: "valid endpoint",
input: "https://service-account.api.stackit.cloud/token",
isValid: true,
},
}

for _, tt := range tests {
Expand Down
Loading