Skip to content

Commit 884cd37

Browse files
authored
Implement AuthenticationHandler for custom auth mechanisms (#1072)
* Implement `AuthenticationHandler` for custom auth mechanisms * Fix xor * Fall through when the cached password doesn't match * Run golangci-lint * Leverage `slices.Contains` * Unexport `Credential` methods and add comment * Add a comment on empty passwords * Add `AUTH_CLEAR_METHOD` to `isAuthMethodSupported`
1 parent 8a40370 commit 884cd37

14 files changed

+357
-198
lines changed

driver/driver_options_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
266266
}
267267

268268
func createMockServer(t *testing.T) *testServer {
269-
inMemProvider := server.NewInMemoryProvider()
270-
require.NoError(t, inMemProvider.AddUser(*testUser, *testPassword))
269+
authHandler := server.NewInMemoryAuthenticationHandler()
270+
require.NoError(t, authHandler.AddUser(*testUser, *testPassword))
271271
defaultServer := server.NewDefaultServer()
272272

273273
l, err := net.Listen("tcp", "127.0.0.1:3307")
@@ -285,7 +285,7 @@ func createMockServer(t *testing.T) *testServer {
285285
}
286286

287287
go func() {
288-
co, err := s.NewCustomizedConn(conn, inMemProvider, handler)
288+
co, err := s.NewCustomizedConn(conn, authHandler, handler)
289289
if err != nil {
290290
return
291291
}

mysql/util.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ func CalcNativePassword(scramble, password []byte) []byte {
5656
return Xor(scrambleHash, stage1)
5757
}
5858

59-
// Xor modifies hash1 in-place with XOR against hash2
59+
// Xor returns a new slice with hash1 XOR hash2, wrapping hash2 if hash1 is longer.
6060
func Xor(hash1 []byte, hash2 []byte) []byte {
61-
l := min(len(hash1), len(hash2))
62-
for i := range l {
63-
hash1[i] ^= hash2[i]
61+
result := make([]byte, len(hash1))
62+
for i := range hash1 {
63+
result[i] = hash1[i] ^ hash2[i%len(hash2)]
6464
}
65-
return hash1
65+
return result
6666
}
6767

6868
// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2))

server/auth.go

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err
3030
return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData)
3131
}
3232

33-
func (c *Conn) acquirePassword() error {
34-
if c.credential.Password != "" {
33+
func (c *Conn) acquireCredential() error {
34+
if len(c.credential.Passwords) > 0 {
3535
return nil
3636
}
37-
credential, found, err := c.credentialProvider.GetCredential(c.user)
37+
credential, found, err := c.authHandler.GetCredential(c.user)
3838
if err != nil {
3939
return err
4040
}
41-
if !found {
41+
if !found || len(credential.Passwords) == 0 {
4242
return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String())
4343
}
4444
c.credential = credential
@@ -67,26 +67,32 @@ func scrambleValidation(cached, nonce, scramble []byte) bool {
6767

6868
func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error {
6969
if len(clientAuthData) == 0 {
70-
if credential.Password == "" {
70+
if credential.hasEmptyPassword() {
7171
return nil
7272
}
7373
return ErrAccessDeniedNoPassword
7474
}
7575

76-
password, err := mysql.DecodePasswordHex(c.credential.Password)
77-
if err != nil {
78-
return ErrAccessDenied
79-
}
80-
if mysql.CompareNativePassword(clientAuthData, password, c.salt) {
81-
return nil
76+
for _, password := range credential.Passwords {
77+
hash, err := credential.hashPassword(password)
78+
if err != nil {
79+
continue
80+
}
81+
decoded, err := mysql.DecodePasswordHex(hash)
82+
if err != nil {
83+
continue
84+
}
85+
if mysql.CompareNativePassword(clientAuthData, decoded, c.salt) {
86+
return nil
87+
}
8288
}
8389
return ErrAccessDenied
8490
}
8591

8692
func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error {
8793
// Empty passwords are not hashed, but sent as empty string
8894
if len(clientAuthData) == 0 {
89-
if credential.Password == "" {
95+
if credential.hasEmptyPassword() {
9096
return nil
9197
}
9298
return ErrAccessDeniedNoPassword
@@ -112,20 +118,26 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C
112118
clientAuthData = clientAuthData[:l-1]
113119
}
114120
}
115-
check, err := mysql.Check256HashingPassword([]byte(credential.Password), string(clientAuthData))
116-
if err != nil {
117-
return err
118-
}
119-
if check {
120-
return nil
121+
for _, password := range credential.Passwords {
122+
hash, err := credential.hashPassword(password)
123+
if err != nil {
124+
continue
125+
}
126+
check, err := mysql.Check256HashingPassword([]byte(hash), string(clientAuthData))
127+
if err != nil {
128+
continue
129+
}
130+
if check {
131+
return nil
132+
}
121133
}
122134
return ErrAccessDenied
123135
}
124136

125137
func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
126138
// Empty passwords are not hashed, but sent as empty string
127139
if len(clientAuthData) == 0 {
128-
if c.credential.Password == "" {
140+
if c.credential.hasEmptyPassword() {
129141
return nil
130142
}
131143
return ErrAccessDeniedNoPassword
@@ -139,10 +151,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
139151
// 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03
140152
return c.writeAuthMoreDataFastAuth()
141153
}
142-
143-
return ErrAccessDenied
144154
}
145-
// cache miss, do full auth
155+
// cache miss or validation failed, do full auth
146156
if err := c.writeAuthMoreDataFullAuth(); err != nil {
147157
return err
148158
}

server/auth_switch_response.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (c *Conn) handleAuthSwitchResponse() error {
2424
}
2525

2626
func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
27-
if err := c.acquirePassword(); err != nil {
27+
if err := c.acquireCredential(); err != nil {
2828
return err
2929
}
3030
if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok {
@@ -72,15 +72,21 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error {
7272

7373
func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error {
7474
if len(clientAuthData) == 0 {
75-
if credential.Password == "" {
75+
if credential.hasEmptyPassword() {
7676
return nil
7777
}
7878
return ErrAccessDeniedNoPassword
7979
}
8080

81-
match, err := auth.CheckHashingPassword([]byte(credential.Password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD)
82-
if match && err == nil {
83-
return nil
81+
for _, password := range credential.Passwords {
82+
hash, err := credential.hashPassword(password)
83+
if err != nil {
84+
continue
85+
}
86+
match, err := auth.CheckHashingPassword([]byte(hash), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD)
87+
if match && err == nil {
88+
return nil
89+
}
8490
}
8591
return ErrAccessDenied
8692
}

server/auth_switch_response_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func TestCheckSha2CacheCredentials_EmptyPassword(t *testing.T) {
3030
for _, tt := range tests {
3131
t.Run(tt.name, func(t *testing.T) {
3232
c := &Conn{
33-
credential: Credential{Password: tt.serverPassword},
33+
credential: Credential{Passwords: []string{tt.serverPassword}},
3434
}
3535
err := c.checkSha2CacheCredentials(tt.clientAuthData, c.credential)
3636
if tt.wantErr == nil {

server/auth_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestCompareNativePasswordAuthData_EmptyPassword(t *testing.T) {
3737
for _, tt := range tests {
3838
t.Run(tt.name, func(t *testing.T) {
3939
c := &Conn{
40-
credential: Credential{Password: tt.serverPassword},
40+
credential: Credential{Passwords: []string{tt.serverPassword}},
4141
}
4242
err := c.compareNativePasswordAuthData(tt.clientAuthData, c.credential)
4343
if tt.wantErr == nil {
@@ -73,7 +73,7 @@ func TestCompareSha256PasswordAuthData_EmptyPassword(t *testing.T) {
7373
for _, tt := range tests {
7474
t.Run(tt.name, func(t *testing.T) {
7575
c := &Conn{
76-
credential: Credential{Password: tt.serverPassword},
76+
credential: Credential{Passwords: []string{tt.serverPassword}},
7777
}
7878
err := c.compareSha256PasswordAuthData(tt.clientAuthData, c.credential)
7979
if tt.wantErr == nil {
@@ -109,7 +109,7 @@ func TestCompareCacheSha2PasswordAuthData_EmptyPassword(t *testing.T) {
109109
for _, tt := range tests {
110110
t.Run(tt.name, func(t *testing.T) {
111111
c := &Conn{
112-
credential: Credential{Password: tt.serverPassword},
112+
credential: Credential{Passwords: []string{tt.serverPassword}},
113113
}
114114
err := c.compareCacheSha2PasswordAuthData(tt.clientAuthData)
115115
if tt.wantErr == nil {

server/authentication_handler.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package server
2+
3+
import (
4+
"slices"
5+
"sync"
6+
7+
"github.com/go-mysql-org/go-mysql/mysql"
8+
"github.com/pingcap/errors"
9+
"github.com/pingcap/tidb/pkg/parser/auth"
10+
)
11+
12+
// AuthenticationHandler provides user credentials and authentication lifecycle hooks.
13+
//
14+
// # Important Note
15+
//
16+
// if the password in a third-party auth handler could be updated at runtime, we have to invalidate the caching
17+
// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'.
18+
type AuthenticationHandler interface {
19+
// GetCredential returns the user credential (supports multiple valid passwords per user).
20+
// Implementations must be safe for concurrent use.
21+
GetCredential(username string) (credential Credential, found bool, err error)
22+
23+
// OnAuthSuccess is called after successful authentication, before the OK packet.
24+
// Return an error to reject the connection (error will be sent to client instead of OK).
25+
// Return nil to proceed with sending the OK packet.
26+
OnAuthSuccess(conn *Conn) error
27+
28+
// OnAuthFailure is called after authentication fails, before the error packet.
29+
// This is informational only - the connection will be closed regardless.
30+
OnAuthFailure(conn *Conn, err error)
31+
}
32+
33+
func NewInMemoryAuthenticationHandler(defaultAuthMethod ...string) *InMemoryAuthenticationHandler {
34+
d := mysql.AUTH_CACHING_SHA2_PASSWORD
35+
if len(defaultAuthMethod) > 0 {
36+
d = defaultAuthMethod[0]
37+
}
38+
return &InMemoryAuthenticationHandler{
39+
userPool: sync.Map{},
40+
defaultAuthMethod: d,
41+
}
42+
}
43+
44+
// Credential holds authentication settings for a user.
45+
// Passwords contains all valid raw passwords for the user. They are hashed on demand during comparison.
46+
// If empty password authentication is allowed, Passwords must contain an empty string (e.g., []string{""})
47+
// rather than being a zero-length slice. A zero-length slice means no valid passwords are configured.
48+
type Credential struct {
49+
Passwords []string
50+
AuthPluginName string
51+
}
52+
53+
// hashPassword computes the password hash for a given password using the credential's auth plugin.
54+
func (c Credential) hashPassword(password string) (string, error) {
55+
if password == "" {
56+
return "", nil
57+
}
58+
59+
switch c.AuthPluginName {
60+
case mysql.AUTH_NATIVE_PASSWORD:
61+
return mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))), nil
62+
63+
case mysql.AUTH_CACHING_SHA2_PASSWORD:
64+
return auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD), nil
65+
66+
case mysql.AUTH_SHA256_PASSWORD:
67+
return mysql.NewSha256PasswordHash(password)
68+
69+
case mysql.AUTH_CLEAR_PASSWORD:
70+
return password, nil
71+
72+
default:
73+
return "", errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName)
74+
}
75+
}
76+
77+
// hasEmptyPassword returns true if any password in the credential is empty.
78+
func (c Credential) hasEmptyPassword() bool {
79+
return slices.Contains(c.Passwords, "")
80+
}
81+
82+
// InMemoryAuthenticationHandler implements AuthenticationHandler with in-memory credential storage.
83+
type InMemoryAuthenticationHandler struct {
84+
userPool sync.Map // username -> Credential
85+
defaultAuthMethod string
86+
}
87+
88+
func (h *InMemoryAuthenticationHandler) CheckUsername(username string) (found bool, err error) {
89+
_, ok := h.userPool.Load(username)
90+
return ok, nil
91+
}
92+
93+
func (h *InMemoryAuthenticationHandler) GetCredential(username string) (credential Credential, found bool, err error) {
94+
v, ok := h.userPool.Load(username)
95+
if !ok {
96+
return Credential{}, false, nil
97+
}
98+
c, valid := v.(Credential)
99+
if !valid {
100+
return Credential{}, true, errors.Errorf("invalid credential")
101+
}
102+
return c, true, nil
103+
}
104+
105+
func (h *InMemoryAuthenticationHandler) AddUser(username, password string, optionalAuthPluginName ...string) error {
106+
authPluginName := h.defaultAuthMethod
107+
if len(optionalAuthPluginName) > 0 {
108+
authPluginName = optionalAuthPluginName[0]
109+
}
110+
111+
if !isAuthMethodSupported(authPluginName) {
112+
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
113+
}
114+
115+
h.userPool.Store(username, Credential{
116+
Passwords: []string{password},
117+
AuthPluginName: authPluginName,
118+
})
119+
return nil
120+
}
121+
122+
func (h *InMemoryAuthenticationHandler) OnAuthSuccess(conn *Conn) error {
123+
return nil
124+
}
125+
126+
func (h *InMemoryAuthenticationHandler) OnAuthFailure(conn *Conn, err error) {
127+
}

0 commit comments

Comments
 (0)