Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions consent/strategy_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ func TestStrategyLoginConsentNext(t *testing.T) {

subject := "aeneas-rekkas"
c := createDefaultClient(t)
now := 1723546027 // Unix timestamps must round-trip through Hydra without converting to floats or similar
testhelpers.NewLoginConsentUI(t, reg.Config(),
acceptLoginHandler(t, subject, &hydra.AcceptOAuth2LoginRequest{
Remember: pointerx.Bool(true),
Expand All @@ -297,8 +298,14 @@ func TestStrategyLoginConsentNext(t *testing.T) {
Remember: pointerx.Bool(true),
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
AccessToken: map[string]interface{}{
"foo": "bar",
"ts1": now,
},
IdToken: map[string]interface{}{
"bar": "baz",
"ts2": now,
},
},
}))

Expand All @@ -314,12 +321,14 @@ func TestStrategyLoginConsentNext(t *testing.T) {
require.NoError(t, err)

claims := testhelpers.IntrospectToken(t, conf, token.AccessToken, adminTS)
assert.Equal(t, "bar", claims.Get("ext.foo").String(), "%s", claims.Raw)
assert.Equalf(t, `"bar"`, claims.Get("ext.foo").Raw, "%s", claims.Raw) // Raw rather than .Int() or .Value() to verify the exact JSON payload
assert.Equalf(t, "1723546027", claims.Get("ext.ts1").Raw, "%s", claims.Raw) // must round-trip as integer

idClaims := testhelpers.DecodeIDToken(t, token)
assert.Equal(t, "baz", idClaims.Get("bar").String(), "%s", idClaims.Raw)
assert.Equalf(t, `"baz"`, idClaims.Get("bar").Raw, "%s", idClaims.Raw) // Raw rather than .Int() or .Value() to verify the exact JSON payload
assert.Equalf(t, "1723546027", idClaims.Get("ts2").Raw, "%s", idClaims.Raw) // must round-trip as integer
sid = idClaims.Get("sid").String()
assert.NotNil(t, sid)
assert.NotEmpty(t, sid)
}

t.Run("perform first flow", run)
Expand All @@ -334,21 +343,28 @@ func TestStrategyLoginConsentNext(t *testing.T) {
assert.Empty(t, pointerx.StringR(res.Client.ClientSecret))
return hydra.AcceptOAuth2LoginRequest{
Subject: subject,
Context: map[string]interface{}{"foo": "bar"},
Context: map[string]interface{}{"xyz": "abc"},
}
}),
checkAndAcceptConsentHandler(t, adminClient, func(t *testing.T, res *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
checkAndAcceptConsentHandler(t, adminClient, func(t *testing.T, req *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
require.NoError(t, err)
assert.True(t, *res.Skip)
assert.Equal(t, sid, *res.LoginSessionId)
assert.Equal(t, subject, *res.Subject)
assert.Empty(t, pointerx.StringR(res.Client.ClientSecret))
assert.True(t, *req.Skip)
assert.Equal(t, sid, *req.LoginSessionId)
assert.Equal(t, subject, *req.Subject)
assert.Empty(t, pointerx.StringR(req.Client.ClientSecret))
assert.Equal(t, map[string]interface{}{"xyz": "abc"}, req.Context)
return hydra.AcceptOAuth2ConsentRequest{
Remember: pointerx.Bool(true),
GrantScope: []string{"openid"},
Session: &hydra.AcceptOAuth2ConsentRequestSession{
AccessToken: map[string]interface{}{"foo": "bar"},
IdToken: map[string]interface{}{"bar": "baz"},
AccessToken: map[string]interface{}{
"foo": "bar",
"ts1": now,
},
IdToken: map[string]interface{}{
"bar": "baz",
"ts2": now,
},
},
}
}))
Expand Down
3 changes: 2 additions & 1 deletion oauth2/.snapshots/TestUnmarshalSession-v1.11.8.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/.snapshots/TestUnmarshalSession-v1.11.9.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/fixtures/v1.11.8-session.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"AuthenticationMethodsReferences": [],
"CodeHash": "",
"Extra": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"Headers": {
Expand Down
3 changes: 2 additions & 1 deletion oauth2/fixtures/v1.11.9-session.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"amr": [],
"c_hash": "",
"ext": {
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d"
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027
}
},
"headers": {
Expand Down
10 changes: 5 additions & 5 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (h *Handler) getOidcUserInfo(w http.ResponseWriter, r *http.Request) {
interim["jti"] = uuid.New()
interim["iat"] = time.Now().Unix()

keyID, err := h.r.OpenIDJWTStrategy().GetPublicKeyID(r.Context())
keyID, err := h.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
h.r.Writer().WriteError(w, r, err)
return
Expand Down Expand Up @@ -727,7 +727,7 @@ type revokeOAuth2Token struct {
// default: errorOAuth2
func (h *Handler) revokeOAuth2Token(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
events.Trace(r.Context(), events.AccessTokenRevoked)
events.Trace(ctx, events.AccessTokenRevoked)

err := h.r.OAuth2Provider().NewRevocationRequest(ctx, r)
if err != nil {
Expand Down Expand Up @@ -980,13 +980,13 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}
session.ClientID = accessRequest.GetClient().GetID()
session.KID = accessTokenKeyID
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(r.Context()).String()
session.DefaultSession.Claims.Issuer = h.c.IssuerURL(ctx).String()
session.DefaultSession.Claims.IssuedAt = time.Now().UTC()

scopes := accessRequest.GetRequestedScopes()

// Added for compatibility with MITREid
if h.c.GrantAllClientCredentialsScopesPerDefault(r.Context()) && len(scopes) == 0 {
if h.c.GrantAllClientCredentialsScopesPerDefault(ctx) && len(scopes) == 0 {
for _, scope := range accessRequest.GetClient().GetScopes() {
accessRequest.GrantScope(scope)
}
Expand Down Expand Up @@ -1089,7 +1089,7 @@ func (h *Handler) oAuth2Authorize(w http.ResponseWriter, r *http.Request, _ http
}

var accessTokenKeyID string
if h.c.AccessTokenStrategy(r.Context(), client.AccessTokenStrategySource(authorizeRequest.GetClient())) == "jwt" {
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(authorizeRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
x.LogError(r, err, h.r.Logger())
Expand Down
28 changes: 15 additions & 13 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
package oauth2

import (
"bytes"
"context"
"encoding/json"
"time"

jjson "github.com/go-jose/go-jose/v3/json"
"github.com/mohae/deepcopy"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/mohae/deepcopy"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"

"github.com/ory/x/logrusx"
"github.com/ory/x/stringslice"
)
Expand Down Expand Up @@ -60,33 +59,33 @@ func NewSessionWithCustomClaims(ctx context.Context, p *config.DefaultProvider,
}

func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
//a slice of claims that are reserved and should not be overridden
var reservedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}
// a slice of claims that are reserved and should not be overridden
reservedClaims := []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

//remove any reserved claims from the custom claims
// remove any reserved claims from the custom claims
allowedClaimsFromConfigWithoutReserved := stringslice.Filter(s.AllowedTopLevelClaims, func(s string) bool {
return stringslice.Has(reservedClaims, s)
})

//our new extra map which will be added to the jwt
var topLevelExtraWithMirrorExt = map[string]interface{}{}
// our new extra map which will be added to the jwt
topLevelExtraWithMirrorExt := map[string]interface{}{}

//setting every allowed claim top level in jwt with respective value
// setting every allowed claim top level in jwt with respective value
for _, allowedClaim := range allowedClaimsFromConfigWithoutReserved {
if cl, ok := s.Extra[allowedClaim]; ok {
topLevelExtraWithMirrorExt[allowedClaim] = cl
}
}

//for every other claim that was already reserved and for mirroring, add original extra under "ext"
// for every other claim that was already reserved and for mirroring, add original extra under "ext"
if s.MirrorTopLevelClaims {
topLevelExtraWithMirrorExt["ext"] = s.Extra
}

claims := &jwt.JWTClaims{
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
//set our custom extra map as claims.Extra
// set our custom extra map as claims.Extra
Extra: topLevelExtraWithMirrorExt,
ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
IssuedAt: time.Now(),
Expand Down Expand Up @@ -185,8 +184,11 @@ func (s *Session) UnmarshalJSON(original []byte) (err error) {
}
}

// https://github.com/go-jose/go-jose/issues/144
dec := jjson.NewDecoder(bytes.NewReader(transformed))
dec.SetNumberType(jjson.UnmarshalIntOrFloat)
type t Session
if err := json.Unmarshal(transformed, (*t)(s)); err != nil {
if err := dec.Decode((*t)(s)); err != nil {
return errors.WithStack(err)
}

Expand Down
5 changes: 3 additions & 2 deletions oauth2/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func TestUnmarshalSession(t *testing.T) {
AuthenticationMethodsReferences: []string{},
CodeHash: "",
Extra: map[string]interface{}{
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"sid": "177e1f44-a1e9-415c-bfa3-8b62280b182d",
"timestamp": 1723546027,
},
},
Headers: &jwt.Headers{Extra: map[string]interface{}{
Expand Down Expand Up @@ -85,7 +86,7 @@ func TestUnmarshalSession(t *testing.T) {
snapshotx.SnapshotTExcept(t, &actual, nil)
})

t.Run("v1.11.9", func(t *testing.T) {
t.Run("v1.11.9" /* and later versions */, func(t *testing.T) {
var actual Session
require.NoError(t, json.Unmarshal(v1119Session, &actual))
assertx.EqualAsJSON(t, expect, &actual)
Expand Down
Loading