fix: port in SNI address when using domainName instead of IP for master (#8500)
This commit is contained in:
@@ -86,6 +86,23 @@ func (sa ServerAddress) ToGrpcAddress() string {
|
|||||||
return ServerToGrpcAddress(string(sa))
|
return ServerToGrpcAddress(string(sa))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToHost returns the host part only, without any port information.
|
||||||
|
func (sa ServerAddress) ToHost() string {
|
||||||
|
httpAddr := sa.ToHttpAddress()
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(httpAddr)
|
||||||
|
if err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: if parsing fails, it's likely a host without a port.
|
||||||
|
// Handle bracketed IPv6 (e.g., "[::1]" without port) by trimming brackets.
|
||||||
|
if strings.HasPrefix(httpAddr, "[") && strings.HasSuffix(httpAddr, "]") {
|
||||||
|
return httpAddr[1 : len(httpAddr)-1]
|
||||||
|
}
|
||||||
|
return httpAddr
|
||||||
|
}
|
||||||
|
|
||||||
// LookUp may return an error for some records along with successful lookups - make sure you do not
|
// LookUp may return an error for some records along with successful lookups - make sure you do not
|
||||||
// discard `addresses` even if `err == nil`
|
// discard `addresses` even if `err == nil`
|
||||||
func (r ServerSrvAddress) LookUp() (addresses []ServerAddress, err error) {
|
func (r ServerSrvAddress) LookUp() (addresses []ServerAddress, err error) {
|
||||||
|
|||||||
@@ -35,6 +35,64 @@ func TestServerAddresses_ToAddressMapOrSrv_shouldHandleIPPortList(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerAddress_ToHost(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
address ServerAddress
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "hostname with port",
|
||||||
|
address: ServerAddress("master.example.com:9333"),
|
||||||
|
expected: "master.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 with port",
|
||||||
|
address: ServerAddress("192.168.1.1:9333"),
|
||||||
|
expected: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 with port",
|
||||||
|
address: ServerAddress("[2001:db8::1]:9333"),
|
||||||
|
expected: "2001:db8::1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname without port",
|
||||||
|
address: ServerAddress("master.example.com"),
|
||||||
|
expected: "master.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hostname with port.grpcPort",
|
||||||
|
address: ServerAddress("master.example.com:443.10443"),
|
||||||
|
expected: "master.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 with port.grpcPort",
|
||||||
|
address: ServerAddress("192.168.1.1:8080.18080"),
|
||||||
|
expected: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 with port.grpcPort",
|
||||||
|
address: ServerAddress("[2001:db8::1]:8080.18080"),
|
||||||
|
expected: "2001:db8::1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bracketed IPv6 without port",
|
||||||
|
address: ServerAddress("[2001:db8::1]"),
|
||||||
|
expected: "2001:db8::1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := tc.address.ToHost()
|
||||||
|
if got != tc.expected {
|
||||||
|
t.Errorf("ServerAddress(%q).ToHost() = %q, want %q", tc.address, got, tc.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIPv6ServerAddressFormatting(t *testing.T) {
|
func TestIPv6ServerAddressFormatting(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package security
|
package security
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,6 +14,7 @@ import (
|
|||||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
|
"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
|
||||||
"google.golang.org/grpc/security/advancedtls"
|
"google.golang.org/grpc/security/advancedtls"
|
||||||
@@ -24,6 +27,37 @@ type Authenticator struct {
|
|||||||
AllowedCommonNames map[string]bool
|
AllowedCommonNames map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SNIStrippingTransportCredentials wraps another TransportCredentials
|
||||||
|
// and strips the port from the authority in ClientHandshake to prevent
|
||||||
|
// advancedtls from using the full "host:port" as ServerName in SNI.
|
||||||
|
type SNIStrippingTransportCredentials struct {
|
||||||
|
creds credentials.TransportCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SNIStrippingTransportCredentials) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
host, _, err := net.SplitHostPort(authority)
|
||||||
|
if err == nil {
|
||||||
|
authority = host
|
||||||
|
}
|
||||||
|
return s.creds.ClientHandshake(ctx, authority, rawConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SNIStrippingTransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
return s.creds.ServerHandshake(rawConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SNIStrippingTransportCredentials) Info() credentials.ProtocolInfo {
|
||||||
|
return s.creds.Info()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SNIStrippingTransportCredentials) Clone() credentials.TransportCredentials {
|
||||||
|
return &SNIStrippingTransportCredentials{creds: s.creds.Clone()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SNIStrippingTransportCredentials) OverrideServerName(serverNameOverride string) error {
|
||||||
|
return s.creds.OverrideServerName(serverNameOverride)
|
||||||
|
}
|
||||||
|
|
||||||
func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption, grpc.ServerOption) {
|
func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption, grpc.ServerOption) {
|
||||||
if config == nil {
|
if config == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -151,7 +185,8 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption {
|
|||||||
glog.Warningf("advancedtls.NewClientCreds(%v) failed: %v", options, err)
|
glog.Warningf("advancedtls.NewClientCreds(%v) failed: %v", options, err)
|
||||||
return grpc.WithTransportCredentials(insecure.NewCredentials())
|
return grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
}
|
}
|
||||||
return grpc.WithTransportCredentials(ta)
|
wrapped := &SNIStrippingTransportCredentials{creds: ta}
|
||||||
|
return grpc.WithTransportCredentials(wrapped)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadClientTLSHTTP(clientCertFile string) *tls.Config {
|
func LoadClientTLSHTTP(clientCertFile string) *tls.Config {
|
||||||
|
|||||||
219
weed/security/tls_sni_test.go
Normal file
219
weed/security/tls_sni_test.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/security/advancedtls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateSelfSignedCert(t *testing.T) (tls.Certificate, *x509.CertPool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate CA key: %v", err)
|
||||||
|
}
|
||||||
|
caTemplate := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: "Test CA"},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create CA cert: %v", err)
|
||||||
|
}
|
||||||
|
caCert, err := x509.ParseCertificate(caDER)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse CA cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate leaf key: %v", err)
|
||||||
|
}
|
||||||
|
leafTemplate := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2),
|
||||||
|
Subject: pkix.Name{CommonName: "localhost"},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||||
|
}
|
||||||
|
leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create leaf cert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
leafCert := tls.Certificate{
|
||||||
|
Certificate: [][]byte{leafDER},
|
||||||
|
PrivateKey: leafKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
pool.AddCert(caCert)
|
||||||
|
|
||||||
|
return leafCert, pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTLSListenerCapturingSNI(t *testing.T, cert tls.Certificate) (string, <-chan string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
sniChan := make(chan string, 1)
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
select {
|
||||||
|
case sniChan <- hello.ServerName:
|
||||||
|
close(sniChan)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.Listen: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { ln.Close() })
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
conn.Read(buf)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ln.Addr().String(), sniChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSNIStrippingCreds(t *testing.T, cert tls.Certificate, pool *x509.CertPool) credentials.TransportCredentials {
|
||||||
|
t.Helper()
|
||||||
|
clientCreds, err := advancedtls.NewClientCreds(&advancedtls.Options{
|
||||||
|
IdentityOptions: advancedtls.IdentityCertificateOptions{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
},
|
||||||
|
RootOptions: advancedtls.RootCertificateOptions{
|
||||||
|
RootCertificates: pool,
|
||||||
|
},
|
||||||
|
VerificationType: advancedtls.CertVerification,
|
||||||
|
AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) {
|
||||||
|
return &advancedtls.PostHandshakeVerificationResults{}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewClientCreds: %v", err)
|
||||||
|
}
|
||||||
|
return &SNIStrippingTransportCredentials{creds: clientCreds}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSNI_HostnameStripsPort(t *testing.T) {
|
||||||
|
cert, pool := generateSelfSignedCert(t)
|
||||||
|
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||||
|
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// gRPC passes "host:port" as authority; SNI wrapper strips the port
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _, err = wrapped.ClientHandshake(ctx, "localhost:"+portFromAddr(addr), conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ClientHandshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sni := <-sniChan:
|
||||||
|
if sni != "localhost" {
|
||||||
|
t.Errorf("SNI = %q, want %q", sni, "localhost")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for SNI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSNI_IPAddressEmptySNI(t *testing.T) {
|
||||||
|
cert, pool := generateSelfSignedCert(t)
|
||||||
|
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||||
|
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// RFC 6066: IP addresses MUST NOT be sent as SNI; Go's TLS sends empty ServerName for IPs
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _, err = wrapped.ClientHandshake(ctx, "127.0.0.1:"+portFromAddr(addr), conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ClientHandshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sni := <-sniChan:
|
||||||
|
if sni != "" {
|
||||||
|
t.Errorf("SNI = %q, want empty string", sni)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for SNI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSNI_IPv6AddressEmptySNI(t *testing.T) {
|
||||||
|
cert, pool := generateSelfSignedCert(t)
|
||||||
|
wrapped := newSNIStrippingCreds(t, cert, pool)
|
||||||
|
addr, sniChan := startTLSListenerCapturingSNI(t, cert)
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _, err = wrapped.ClientHandshake(ctx, "[::1]:"+portFromAddr(addr), conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ClientHandshake: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case sni := <-sniChan:
|
||||||
|
if sni != "" {
|
||||||
|
t.Errorf("SNI = %q, want empty string", sni)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for SNI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func portFromAddr(addr string) string {
|
||||||
|
_, port, _ := net.SplitHostPort(addr)
|
||||||
|
return port
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user