Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,15 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa
}

flowConfig := &discovery.OAuthFlowConfig{
ClientID: remoteAuthFlags.RemoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL,
TokenURL: remoteAuthFlags.RemoteAuthTokenURL,
Scopes: remoteAuthFlags.RemoteAuthScopes,
CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort,
Timeout: remoteAuthFlags.RemoteAuthTimeout,
SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser,
ClientID: remoteAuthFlags.RemoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL,
TokenURL: remoteAuthFlags.RemoteAuthTokenURL,
Scopes: remoteAuthFlags.RemoteAuthScopes,
CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort,
Timeout: remoteAuthFlags.RemoteAuthTimeout,
SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser,
IssuerProvided: remoteAuthFlags.RemoteAuthIssuer != "", // Issuer was explicitly provided
}

result, err := discovery.PerformOAuthFlow(ctx, remoteAuthFlags.RemoteAuthIssuer, flowConfig)
Expand All @@ -325,14 +326,15 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa

// Perform OAuth flow with discovered configuration
flowConfig := &discovery.OAuthFlowConfig{
ClientID: remoteAuthFlags.RemoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL,
TokenURL: remoteAuthFlags.RemoteAuthTokenURL,
Scopes: remoteAuthFlags.RemoteAuthScopes,
CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort,
Timeout: remoteAuthFlags.RemoteAuthTimeout,
SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser,
ClientID: remoteAuthFlags.RemoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthFlags.RemoteAuthAuthorizeURL,
TokenURL: remoteAuthFlags.RemoteAuthTokenURL,
Scopes: remoteAuthFlags.RemoteAuthScopes,
CallbackPort: remoteAuthFlags.RemoteAuthCallbackPort,
Timeout: remoteAuthFlags.RemoteAuthTimeout,
SkipBrowser: remoteAuthFlags.RemoteAuthSkipBrowser,
IssuerProvided: false, // Issuer was derived from WWW-Authenticate header
}

result, err := discovery.PerformOAuthFlow(ctx, authInfo.Realm, flowConfig)
Expand Down
21 changes: 11 additions & 10 deletions pkg/auth/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,16 @@ func DeriveIssuerFromURL(remoteURL string) string {

// OAuthFlowConfig contains configuration for performing OAuth flows
type OAuthFlowConfig struct {
ClientID string
ClientSecret string
AuthorizeURL string // Manual OAuth endpoint (optional)
TokenURL string // Manual OAuth endpoint (optional)
Scopes []string
CallbackPort int
Timeout time.Duration
SkipBrowser bool
OAuthParams map[string]string
ClientID string
ClientSecret string
AuthorizeURL string // Manual OAuth endpoint (optional)
TokenURL string // Manual OAuth endpoint (optional)
Scopes []string
CallbackPort int
Timeout time.Duration
SkipBrowser bool
OAuthParams map[string]string
IssuerProvided bool // Whether the issuer was explicitly provided (not derived from URL)
}

// OAuthFlowResult contains the result of an OAuth flow
Expand All @@ -272,7 +273,7 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi
var oauthConfig *oauth.Config
var err error
if shouldDynamicallyRegisterClient(config) {
discoveredDoc, err := oauth.DiscoverOIDCEndpoints(ctx, issuer)
discoveredDoc, err := oauth.DiscoverOIDCEndpoints(ctx, issuer, config.IssuerProvided)
if err != nil {
return nil, fmt.Errorf("failed to discover registration endpoint: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/auth/oauth/dynamic_registration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestDiscoverOIDCEndpointsWithRegistration(t *testing.T) {
issuer = server.URL
}

result, err := DiscoverOIDCEndpoints(context.Background(), issuer)
result, err := DiscoverOIDCEndpoints(context.Background(), issuer, false)

if tt.expectedError {
assert.Error(t, err)
Expand Down
22 changes: 15 additions & 7 deletions pkg/auth/oauth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ type httpClient interface {
}

// DiscoverOIDCEndpoints discovers OAuth endpoints from an OIDC issuer
func DiscoverOIDCEndpoints(ctx context.Context, issuer string) (*OIDCDiscoveryDocument, error) {
return discoverOIDCEndpointsWithClient(ctx, issuer, nil)
// Uses flexible issuer validation to support cases where issuer is derived from URL
Copy link
Preview

Copilot AI Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The comment should end with a period for consistency with other function documentation.

Suggested change
// Uses flexible issuer validation to support cases where issuer is derived from URL
// Uses flexible issuer validation to support cases where issuer is derived from URL.

Copilot uses AI. Check for mistakes.

func DiscoverOIDCEndpoints(ctx context.Context, issuer string, validateIssuerMatch bool) (*OIDCDiscoveryDocument, error) {
return discoverOIDCEndpointsWithClient(ctx, issuer, nil, validateIssuerMatch)
}

// discoverOIDCEndpointsWithClient discovers OAuth endpoints from an OIDC issuer with a custom HTTP client (private for testing)
func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client httpClient) (*OIDCDiscoveryDocument, error) {
func discoverOIDCEndpointsWithClient(
ctx context.Context,
issuer string,
client httpClient,
validateIssuerMatch bool,
) (*OIDCDiscoveryDocument, error) {
// Validate issuer URL
issuerURL, err := url.Parse(issuer)
if err != nil {
Expand Down Expand Up @@ -98,7 +104,7 @@ func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client
if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&doc); err != nil {
return nil, fmt.Errorf("%s: unexpected response: %w", urlStr, err)
}
if err := validateOIDCDocument(&doc, issuer, oidc); err != nil {
if err := validateOIDCDocument(&doc, issuer, validateIssuerMatch, oidc); err != nil {
return nil, fmt.Errorf("%s: invalid metadata: %w", urlStr, err)
}
return &doc, nil
Expand All @@ -120,12 +126,14 @@ func discoverOIDCEndpointsWithClient(ctx context.Context, issuer string, client
}

// validateOIDCDocument validates the OIDC discovery document
func validateOIDCDocument(doc *OIDCDiscoveryDocument, expectedIssuer string, oidc bool) error {
func validateOIDCDocument(doc *OIDCDiscoveryDocument, expectedIssuer string, validateIssuerMatch bool, oidc bool) error {
if doc.Issuer == "" {
return fmt.Errorf("missing issuer")
}

if doc.Issuer != expectedIssuer {
// Only validate issuer match if explicitly requested
// This allows for cases where issuer is derived from URL and might not match exactly
if validateIssuerMatch && doc.Issuer != expectedIssuer {
return fmt.Errorf("issuer mismatch: expected %s, got %s", expectedIssuer, doc.Issuer)
}

Expand Down Expand Up @@ -184,7 +192,7 @@ func createOAuthConfigFromOIDCWithClient(
client httpClient,
) (*Config, error) {
// Discover OIDC endpoints
doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client)
doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client, true)
if err != nil {
return nil, fmt.Errorf("failed to discover OIDC endpoints: %w", err)
}
Expand Down
116 changes: 82 additions & 34 deletions pkg/auth/oauth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func testDiscoverOIDCEndpoints(
}

// Validate that we got the required fields
if err := validateOIDCDocument(&doc, issuer, true); err != nil {
if err := validateOIDCDocument(&doc, issuer, true, true); err != nil {
return nil, fmt.Errorf("invalid OIDC configuration: %w", err)
}

Expand Down Expand Up @@ -319,11 +319,12 @@ func TestDiscoverOIDCEndpoints(t *testing.T) {
func TestValidateOIDCDocument(t *testing.T) {
t.Parallel()
tests := []struct {
name string
doc *OIDCDiscoveryDocument
expectedIssuer string
expectError bool
errorMsg string
name string
doc *OIDCDiscoveryDocument
expectedIssuer string
validateIssuerMatch bool
expectError bool
errorMsg string
}{
{
name: "missing issuer",
Expand All @@ -332,21 +333,23 @@ func TestValidateOIDCDocument(t *testing.T) {
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "missing issuer",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "missing issuer",
},
{
name: "issuer mismatch",
name: "issuer mismatch (strict validation)",
doc: &OIDCDiscoveryDocument{
Issuer: "https://malicious.com",
AuthorizationEndpoint: "https://example.com/auth",
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "issuer mismatch",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "issuer mismatch",
},
{
name: "missing authorization endpoint",
Expand All @@ -355,9 +358,10 @@ func TestValidateOIDCDocument(t *testing.T) {
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "missing authorization_endpoint",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "missing authorization_endpoint",
},
{
name: "missing token endpoint",
Expand All @@ -366,9 +370,10 @@ func TestValidateOIDCDocument(t *testing.T) {
AuthorizationEndpoint: "https://example.com/auth",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "missing token_endpoint",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "missing token_endpoint",
},
{
name: "missing JWKS URI",
Expand All @@ -377,9 +382,10 @@ func TestValidateOIDCDocument(t *testing.T) {
AuthorizationEndpoint: "https://example.com/auth",
TokenEndpoint: "https://example.com/token",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "missing jwks_uri",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "missing jwks_uri",
},
{
name: "invalid authorization endpoint URL",
Expand All @@ -389,9 +395,10 @@ func TestValidateOIDCDocument(t *testing.T) {
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "invalid authorization_endpoint",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "invalid authorization_endpoint",
},
{
name: "non-HTTPS endpoint (security check)",
Expand All @@ -401,9 +408,10 @@ func TestValidateOIDCDocument(t *testing.T) {
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
expectError: true,
errorMsg: "invalid authorization_endpoint",
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: true,
errorMsg: "invalid authorization_endpoint",
},
{
name: "valid document",
Expand All @@ -414,8 +422,9 @@ func TestValidateOIDCDocument(t *testing.T) {
JWKSURI: "https://example.com/jwks",
UserinfoEndpoint: "https://example.com/userinfo",
},
expectedIssuer: "https://example.com",
expectError: false,
expectedIssuer: "https://example.com",
validateIssuerMatch: true,
expectError: false,
},
{
name: "localhost endpoints allowed",
Expand All @@ -425,15 +434,54 @@ func TestValidateOIDCDocument(t *testing.T) {
TokenEndpoint: "http://localhost:8080/token",
JWKSURI: "http://localhost:8080/jwks",
},
expectedIssuer: "http://localhost:8080",
expectError: false,
expectedIssuer: "http://localhost:8080",
validateIssuerMatch: true,
expectError: false,
},
// Flexible validation test cases
{
name: "flexible validation allows issuer mismatch",
doc: &OIDCDiscoveryDocument{
Issuer: "https://auth.example.com", // Different from expected
AuthorizationEndpoint: "https://auth.example.com/auth",
TokenEndpoint: "https://auth.example.com/token",
JWKSURI: "https://auth.example.com/jwks",
},
expectedIssuer: "https://example.com", // Expected issuer
validateIssuerMatch: false, // Flexible validation
expectError: false, // Should NOT error with flexible validation
},
{
name: "flexible validation allows derived issuer mismatch (Neon scenario)",
doc: &OIDCDiscoveryDocument{
Issuer: "https://auth.neon.com", // Different from derived issuer
AuthorizationEndpoint: "https://auth.neon.com/oauth/authorize",
TokenEndpoint: "https://auth.neon.com/oauth/token",
JWKSURI: "https://auth.neon.com/.well-known/jwks.json",
},
expectedIssuer: "https://api.neon.com", // Derived from URL
validateIssuerMatch: false, // Flexible validation
expectError: false, // Should NOT error with flexible validation
},
{
name: "flexible validation still requires issuer field",
doc: &OIDCDiscoveryDocument{
// Missing issuer field
AuthorizationEndpoint: "https://example.com/auth",
TokenEndpoint: "https://example.com/token",
JWKSURI: "https://example.com/jwks",
},
expectedIssuer: "https://example.com",
validateIssuerMatch: false, // Flexible validation
expectError: true,
errorMsg: "missing issuer",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := validateOIDCDocument(tt.doc, tt.expectedIssuer, true)
err := validateOIDCDocument(tt.doc, tt.expectedIssuer, tt.validateIssuerMatch, true)

if tt.expectError {
require.Error(t, err)
Expand Down Expand Up @@ -1054,7 +1102,7 @@ func TestDiscoverOIDCEndpoints_Production(t *testing.T) {
},
}
}
doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client)
doc, err := discoverOIDCEndpointsWithClient(ctx, issuer, client, true)

if tt.expectError {
require.Error(t, err)
Expand Down
20 changes: 11 additions & 9 deletions pkg/runner/remote_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string)
// Handle OAuth authentication
if authInfo.Type == "OAuth" {
issuer := h.config.Issuer
issuerProvided := issuer != ""
if issuer == "" {
issuer = discovery.DeriveIssuerFromURL(remoteURL)
}
Expand All @@ -51,15 +52,16 @@ func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string)

// Create OAuth flow config from RemoteAuthConfig
flowConfig := &discovery.OAuthFlowConfig{
ClientID: h.config.ClientID,
ClientSecret: h.config.ClientSecret,
AuthorizeURL: h.config.AuthorizeURL,
TokenURL: h.config.TokenURL,
Scopes: h.config.Scopes,
CallbackPort: h.config.CallbackPort,
Timeout: h.config.Timeout,
SkipBrowser: h.config.SkipBrowser,
OAuthParams: h.config.OAuthParams,
ClientID: h.config.ClientID,
ClientSecret: h.config.ClientSecret,
AuthorizeURL: h.config.AuthorizeURL,
TokenURL: h.config.TokenURL,
Scopes: h.config.Scopes,
CallbackPort: h.config.CallbackPort,
Timeout: h.config.Timeout,
SkipBrowser: h.config.SkipBrowser,
OAuthParams: h.config.OAuthParams,
IssuerProvided: issuerProvided,
}

result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig)
Expand Down
Loading