fix: port in SNI address when using domainName instead of IP for master (#8500)
This commit is contained in:
@@ -86,6 +86,23 @@ func (sa ServerAddress) ToGrpcAddress() string {
|
||||
return ServerToGrpcAddress(string(sa))
|
||||
}
|
||||
|
||||
// ToHost returns the host part only, without any port information.
|
||||
func (sa ServerAddress) ToHost() string {
|
||||
httpAddr := sa.ToHttpAddress()
|
||||
|
||||
host, _, err := net.SplitHostPort(httpAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
|
||||
// Fallback: if parsing fails, it's likely a host without a port.
|
||||
// Handle bracketed IPv6 (e.g., "[::1]" without port) by trimming brackets.
|
||||
if strings.HasPrefix(httpAddr, "[") && strings.HasSuffix(httpAddr, "]") {
|
||||
return httpAddr[1 : len(httpAddr)-1]
|
||||
}
|
||||
return httpAddr
|
||||
}
|
||||
|
||||
// LookUp may return an error for some records along with successful lookups - make sure you do not
|
||||
// discard `addresses` even if `err == nil`
|
||||
func (r ServerSrvAddress) LookUp() (addresses []ServerAddress, err error) {
|
||||
|
||||
@@ -35,6 +35,64 @@ func TestServerAddresses_ToAddressMapOrSrv_shouldHandleIPPortList(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerAddress_ToHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
address ServerAddress
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "hostname with port",
|
||||
address: ServerAddress("master.example.com:9333"),
|
||||
expected: "master.example.com",
|
||||
},
|
||||
{
|
||||
name: "IPv4 with port",
|
||||
address: ServerAddress("192.168.1.1:9333"),
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port",
|
||||
address: ServerAddress("[2001:db8::1]:9333"),
|
||||
expected: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
name: "hostname without port",
|
||||
address: ServerAddress("master.example.com"),
|
||||
expected: "master.example.com",
|
||||
},
|
||||
{
|
||||
name: "hostname with port.grpcPort",
|
||||
address: ServerAddress("master.example.com:443.10443"),
|
||||
expected: "master.example.com",
|
||||
},
|
||||
{
|
||||
name: "IPv4 with port.grpcPort",
|
||||
address: ServerAddress("192.168.1.1:8080.18080"),
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port.grpcPort",
|
||||
address: ServerAddress("[2001:db8::1]:8080.18080"),
|
||||
expected: "2001:db8::1",
|
||||
},
|
||||
{
|
||||
name: "bracketed IPv6 without port",
|
||||
address: ServerAddress("[2001:db8::1]"),
|
||||
expected: "2001:db8::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.address.ToHost()
|
||||
if got != tc.expected {
|
||||
t.Errorf("ServerAddress(%q).ToHost() = %q, want %q", tc.address, got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPv6ServerAddressFormatting(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
|
||||
"google.golang.org/grpc/security/advancedtls"
|
||||
@@ -24,6 +27,37 @@ type Authenticator struct {
|
||||
AllowedCommonNames map[string]bool
|
||||
}
|
||||
|
||||
// SNIStrippingTransportCredentials wraps another TransportCredentials
|
||||
// and strips the port from the authority in ClientHandshake to prevent
|
||||
// advancedtls from using the full "host:port" as ServerName in SNI.
|
||||
type SNIStrippingTransportCredentials struct {
|
||||
creds credentials.TransportCredentials
|
||||
}
|
||||
|
||||
func (s *SNIStrippingTransportCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
host, _, err := net.SplitHostPort(authority)
|
||||
if err == nil {
|
||||
authority = host
|
||||
}
|
||||
return s.creds.ClientHandshake(ctx, authority, rawConn)
|
||||
}
|
||||
|
||||
func (s *SNIStrippingTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return s.creds.ServerHandshake(rawConn)
|
||||
}
|
||||
|
||||
func (s *SNIStrippingTransportCredentials) Info() credentials.ProtocolInfo {
|
||||
return s.creds.Info()
|
||||
}
|
||||
|
||||
func (s *SNIStrippingTransportCredentials) Clone() credentials.TransportCredentials {
|
||||
return &SNIStrippingTransportCredentials{creds: s.creds.Clone()}
|
||||
}
|
||||
|
||||
func (s *SNIStrippingTransportCredentials) OverrideServerName(serverNameOverride string) error {
|
||||
return s.creds.OverrideServerName(serverNameOverride)
|
||||
}
|
||||
|
||||
func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption, grpc.ServerOption) {
|
||||
if config == nil {
|
||||
return nil, nil
|
||||
@@ -151,7 +185,8 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption {
|
||||
glog.Warningf("advancedtls.NewClientCreds(%v) failed: %v", options, err)
|
||||
return grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
}
|
||||
return grpc.WithTransportCredentials(ta)
|
||||
wrapped := &SNIStrippingTransportCredentials{creds: ta}
|
||||
return grpc.WithTransportCredentials(wrapped)
|
||||
}
|
||||
|
||||
func LoadClientTLSHTTP(clientCertFile string) *tls.Config {
|
||||
|
||||
219
weed/security/tls_sni_test.go
Normal file
219
weed/security/tls_sni_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/security/advancedtls"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert(t *testing.T) (tls.Certificate, *x509.CertPool) {
|
||||
t.Helper()
|
||||
|
||||
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate CA key: %v", err)
|
||||
}
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("create CA cert: %v", err)
|
||||
}
|
||||
caCert, err := x509.ParseCertificate(caDER)
|
||||
if err != nil {
|
||||
t.Fatalf("parse CA cert: %v", err)
|
||||
}
|
||||
|
||||
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate leaf key: %v", err)
|
||||
}
|
||||
leafTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||
}
|
||||
leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("create leaf cert: %v", err)
|
||||
}
|
||||
|
||||
leafCert := tls.Certificate{
|
||||
Certificate: [][]byte{leafDER},
|
||||
PrivateKey: leafKey,
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(caCert)
|
||||
|
||||
return leafCert, pool
|
||||
}
|
||||
|
||||
func startTLSListenerCapturingSNI(t *testing.T, cert tls.Certificate) (string, <-chan string) {
|
||||
t.Helper()
|
||||
|
||||
sniChan := make(chan string, 1)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
select {
|
||||
case sniChan <- hello.ServerName:
|
||||
close(sniChan)
|
||||
default:
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("tls.Listen: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { ln.Close() })
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1)
|
||||
conn.Read(buf)
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), sniChan
|
||||
}
|
||||
|
||||
func newSNIStrippingCreds(t *testing.T, cert tls.Certificate, pool *x509.CertPool) credentials.TransportCredentials {
|
||||
t.Helper()
|
||||
clientCreds, err := advancedtls.NewClientCreds(&advancedtls.Options{
|
||||
IdentityOptions: advancedtls.IdentityCertificateOptions{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
},
|
||||
RootOptions: advancedtls.RootCertificateOptions{
|
||||
RootCertificates: pool,
|
||||
},
|
||||
VerificationType: advancedtls.CertVerification,
|
||||
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
|
||||
return &advancedtls.PostHandshakeVerificationResults{}, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientCreds: %v", err)
|
||||
}
|
||||
return &SNIStrippingTransportCredentials{creds: clientCreds}
|
||||
}
|
||||
|
||||
func TestSNI_HostnameStripsPort(t *testing.T) {
|
||||
cert, pool := generateSelfSignedCert(t)
|
||||
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// gRPC passes "host:port" as authority; SNI wrapper strips the port
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _, err = wrapped.ClientHandshake(ctx, "localhost:"+portFromAddr(addr), conn)
|
||||
if err != nil {
|
||||
t.Fatalf("ClientHandshake: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case sni := <-sniChan:
|
||||
if sni != "localhost" {
|
||||
t.Errorf("SNI = %q, want %q", sni, "localhost")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for SNI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSNI_IPAddressEmptySNI(t *testing.T) {
|
||||
cert, pool := generateSelfSignedCert(t)
|
||||
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// RFC 6066: IP addresses MUST NOT be sent as SNI; Go's TLS sends empty ServerName for IPs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _, err = wrapped.ClientHandshake(ctx, "127.0.0.1:"+portFromAddr(addr), conn)
|
||||
if err != nil {
|
||||
t.Fatalf("ClientHandshake: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case sni := <-sniChan:
|
||||
if sni != "" {
|
||||
t.Errorf("SNI = %q, want empty string", sni)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for SNI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSNI_IPv6AddressEmptySNI(t *testing.T) {
|
||||
cert, pool := generateSelfSignedCert(t)
|
||||
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _, err = wrapped.ClientHandshake(ctx, "[::1]:"+portFromAddr(addr), conn)
|
||||
if err != nil {
|
||||
t.Fatalf("ClientHandshake: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case sni := <-sniChan:
|
||||
if sni != "" {
|
||||
t.Errorf("SNI = %q, want empty string", sni)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for SNI")
|
||||
}
|
||||
}
|
||||
|
||||
func portFromAddr(addr string) string {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
return port
|
||||
}
|
||||
Reference in New Issue
Block a user