From 8dc06d8c2a3c9eced14fb4a3840894f77d8eece8 Mon Sep 17 00:00:00 2001 From: Andrew Metcalf Date: Wed, 8 Oct 2014 16:00:10 -0700 Subject: [PATCH] Default TLS ServerName to the host in the DSN. A TLS configuration must either have a ServerName or specify InsecureSkipVerify. In most cases, the ServerName value will match the host part of the address in the DSN. This change updates the DSN parser to default the ServerName to the host value provided unless InsecureSkipVerify is specified. --- AUTHORS | 1 + utils.go | 8 ++++++++ utils_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/AUTHORS b/AUTHORS index f0b070246..d851a0477 100644 --- a/AUTHORS +++ b/AUTHORS @@ -34,3 +34,4 @@ Xiuming Chen Barracuda Networks, Inc. Google Inc. +Stripe Inc. diff --git a/utils.go b/utils.go index 56f1b082e..98dfc6f5e 100644 --- a/utils.go +++ b/utils.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "io" + "net" "net/url" "strings" "time" @@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) { if strings.ToLower(value) == "skip-verify" { cfg.tls = &tls.Config{InsecureSkipVerify: true} } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.addr) + if err == nil { + tlsConfig.ServerName = host + } + } + cfg.tls = tlsConfig } else { return fmt.Errorf("Invalid value / unknown config name: %s", value) diff --git a/utils_test.go b/utils_test.go index 6e50b09b9..0855374b7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "crypto/tls" "encoding/binary" "fmt" "testing" @@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) { } } +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "user:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := parseDSN(tst) + if err == nil { + t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = parseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = parseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } + + DeregisterTLSConfig("utils_test") +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs()