diff --git a/pkg/modules/auth/openid/openid.go b/pkg/modules/auth/openid/openid.go index 8cd169bd..8c62664e 100644 --- a/pkg/modules/auth/openid/openid.go +++ b/pkg/modules/auth/openid/openid.go @@ -59,6 +59,7 @@ type claims struct { Email string `json:"email"` Name string `json:"name"` PreferredUsername string `json:"preferred_username"` + Nickname string `json:"nickname"` } func init() { @@ -138,9 +139,40 @@ func HandleCallback(c echo.Context) error { return handler.HandleHTTPError(err, c) } - if cl.Email == "" { - log.Errorf("Claim does not contain an email address for provider %s", provider.Name) - return handler.HandleHTTPError(&user.ErrNoOpenIDEmailProvided{}, c) + if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" { + info, err := provider.OpenIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token)) + if err != nil { + log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err) + return handler.HandleHTTPError(err, c) + } + + cl2 := &claims{} + err = info.Claims(cl2) + if err != nil { + log.Errorf("Error parsing userinfo claims for provider %s: %v", provider.Name, err) + return handler.HandleHTTPError(err, c) + } + + if cl.Email == "" { + cl.Email = cl2.Email + } + + if cl.Name == "" { + cl.Name = cl2.Name + } + + if cl.PreferredUsername == "" { + cl.PreferredUsername = cl2.PreferredUsername + } + + if cl.PreferredUsername == "" && cl2.Nickname != "" { + cl.PreferredUsername = cl2.Nickname + } + + if cl.Email == "" { + log.Errorf("Claim does not contain an email address for provider %s", provider.Name) + return handler.HandleHTTPError(&user.ErrNoOpenIDEmailProvided{}, c) + } } s := db.NewSession() diff --git a/pkg/modules/auth/openid/providers.go b/pkg/modules/auth/openid/providers.go index 2a2c2f61..5ed8ce59 100644 --- a/pkg/modules/auth/openid/providers.go +++ b/pkg/modules/auth/openid/providers.go @@ -79,7 +79,7 @@ func GetAllProviders() (providers []*Provider, err error) { func GetProvider(key string) (provider *Provider, err error) { var p interface{} p, exists, err := keyvalue.Get("openid_provider_" + key) - if exists { + if !exists { _, err = GetAllProviders() // This will put all providers in cache if err != nil { return nil, err