* chore: remove unreachable dead code across the codebase Remove ~50,000 lines of unreachable code identified by static analysis. Major removals: - weed/filer/redis_lua: entire unused Redis Lua filer store implementation - weed/wdclient/net2, resource_pool: unused connection/resource pool packages - weed/plugin/worker/lifecycle: unused lifecycle plugin worker - weed/s3api: unused S3 policy templates, presigned URL IAM, streaming copy, multipart IAM, key rotation, and various SSE helper functions - weed/mq/kafka: unused partition mapping, compression, schema, and protocol functions - weed/mq/offset: unused SQL storage and migration code - weed/worker: unused registry, task, and monitoring functions - weed/query: unused SQL engine, parquet scanner, and type functions - weed/shell: unused EC proportional rebalance functions - weed/storage/erasure_coding/distribution: unused distribution analysis functions - Individual unreachable functions removed from 150+ files across admin, credential, filer, iam, kms, mount, mq, operation, pb, s3api, server, shell, storage, topology, and util packages * fix(s3): reset shared memory store in IAM test to prevent flaky failure TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey was flaky because the MemoryStore credential backend is a singleton registered via init(). Earlier tests that create anonymous identities pollute the shared store, causing LookupAnonymous() to unexpectedly return true. Fix by calling Reset() on the memory store before the test runs. * style: run gofmt on changed files * fix: restore KMS functions used by integration tests * fix(plugin): prevent panic on send to closed worker session channel The Plugin.sendToWorker method could panic with "send on closed channel" when a worker disconnected while a message was being sent. The race was between streamSession.close() closing the outgoing channel and sendToWorker writing to it concurrently. 
Add a done channel to streamSession that is closed before the outgoing channel, and check it in sendToWorker's select to safely detect closed sessions without panicking.
700 lines
17 KiB
Go
700 lines
17 KiB
Go
package postgres
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/md5"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/seaweedfs/seaweedfs/weed/glog"
|
|
"github.com/seaweedfs/seaweedfs/weed/query/engine"
|
|
"github.com/seaweedfs/seaweedfs/weed/util/version"
|
|
)
|
|
|
|
// Start-up request codes. These travel in the 4-byte version field of the
// first packet a client sends (see handleStartup). For a real protocol version
// the high 16 bits are the major number and the low 16 bits the minor number.
const (
	PG_PROTOCOL_VERSION_3 = 3 << 16         // protocol 3.0: 196608 (0x00030000)
	PG_SSL_REQUEST        = 1234<<16 | 5679 // SSL request: 80877103 (0x04d2162f)
	PG_GSSAPI_REQUEST     = 1234<<16 | 5680 // GSSAPI request: 80877104 (0x04d21630)
)

// Message type bytes sent by the client. Start-up packets carry no type byte
// (handleStartup reads their length first), so PG_MSG_STARTUP is only a
// placeholder value.
const (
	PG_MSG_STARTUP   = 0x00
	PG_MSG_QUERY     = 'Q'
	PG_MSG_PARSE     = 'P'
	PG_MSG_BIND      = 'B'
	PG_MSG_EXECUTE   = 'E'
	PG_MSG_DESCRIBE  = 'D'
	PG_MSG_CLOSE     = 'C'
	PG_MSG_FLUSH     = 'H'
	PG_MSG_SYNC      = 'S'
	PG_MSG_TERMINATE = 'X'
	PG_MSG_PASSWORD  = 'p'
)

// Message type bytes sent to the client. PG_RESP_AUTH_OK ('R') heads every
// Authentication message this server writes, including the cleartext and MD5
// password requests, not only AuthenticationOk.
const (
	PG_RESP_AUTH_OK        = 'R'
	PG_RESP_BACKEND_KEY    = 'K'
	PG_RESP_PARAMETER      = 'S'
	PG_RESP_READY          = 'Z'
	PG_RESP_COMMAND        = 'C'
	PG_RESP_DATA_ROW       = 'D'
	PG_RESP_ROW_DESC       = 'T'
	PG_RESP_PARSE_COMPLETE = '1'
	PG_RESP_BIND_COMPLETE  = '2'
	PG_RESP_CLOSE_COMPLETE = '3'
	PG_RESP_ERROR          = 'E'
	PG_RESP_NOTICE         = 'N'
)

// Transaction status indicators: idle, inside a transaction, and inside a
// failed transaction. New sessions start as PG_TRANS_IDLE.
const (
	PG_TRANS_IDLE    = 'I'
	PG_TRANS_INTRANS = 'T'
	PG_TRANS_ERROR   = 'E'
)

// Authentication request codes carried in an Authentication ('R') message.
//
// NOTE(review): AUTH_TRUST is not referenced in this file. In the PostgreSQL
// protocol, code 10 means AuthenticationSASL; confirm this value is never sent
// to a client as a "trust" code.
const (
	AUTH_OK    = 0
	AUTH_CLEAR = 3
	AUTH_MD5   = 5
	AUTH_TRUST = 10
)

// PostgreSQL type OIDs.
const (
	PG_TYPE_BOOL      = 16
	PG_TYPE_BYTEA     = 17
	PG_TYPE_INT8      = 20
	PG_TYPE_INT4      = 23
	PG_TYPE_TEXT      = 25
	PG_TYPE_FLOAT4    = 700
	PG_TYPE_FLOAT8    = 701
	PG_TYPE_VARCHAR   = 1043
	PG_TYPE_TIMESTAMP = 1114
	PG_TYPE_JSON      = 114
	PG_TYPE_JSONB     = 3802
)

// DEFAULT_POSTGRES_PORT is the listen port NewPostgreSQLServer substitutes
// when PostgreSQLServerConfig.Port is zero or negative.
const DEFAULT_POSTGRES_PORT = 5432
|
|
|
|
// AuthMethod selects how handleAuthentication verifies a connecting client.
type AuthMethod int

// Supported authentication methods.
const (
	// AuthTrust accepts every client without asking for a password.
	AuthTrust AuthMethod = 0
	// AuthPassword asks for a cleartext password and checks it against Users.
	AuthPassword AuthMethod = 1
	// AuthMD5 performs salted MD5 password authentication against Users.
	AuthMD5 AuthMethod = 2
)
|
|
|
|
// PostgreSQLServerConfig holds the settings for a PostgreSQLServer.
// NewPostgreSQLServer replaces unset (zero/empty or negative) fields with the
// defaults noted below and writes them back into this struct.
type PostgreSQLServerConfig struct {
	// Host is the listen host; defaults to "localhost".
	Host string
	// Port is the listen TCP port; defaults to DEFAULT_POSTGRES_PORT.
	Port int
	// AuthMethod selects the branch taken by handleAuthentication.
	AuthMethod AuthMethod
	// Users maps user name to password; consulted by password and MD5 auth.
	Users map[string]string
	// TLSConfig, when non-nil, makes Start listen via tls.Listen, so every
	// accepted connection is wrapped in TLS.
	// NOTE(review): the startup code answers a plaintext SSLRequest with 'N';
	// confirm which clients can connect when this is set.
	TLSConfig *tls.Config
	// MaxConns caps the number of registered sessions; further connections
	// are closed by acceptConnections. Defaults to 100.
	MaxConns int
	// IdleTimeout is how long a session may go without activity before
	// cleanupIdleSessions closes it; defaults to one hour.
	IdleTimeout time.Duration
	// StartupTimeout bounds the start-up handshake; defaults to 30 seconds.
	StartupTimeout time.Duration // Timeout for client startup handshake
	// Database is assigned to sessions whose start-up message names no
	// database; defaults to "default".
	Database string
}
|
|
|
|
// PostgreSQLServer accepts PostgreSQL wire-protocol connections. Create it with
// NewPostgreSQLServer, then call Start and, eventually, Stop.
type PostgreSQLServer struct {
	// config is the (defaulted) configuration passed to NewPostgreSQLServer.
	config *PostgreSQLServerConfig
	// listener is set by Start and closed by Stop.
	listener net.Listener
	// sqlEngine executes queries (presumably used by the message handlers
	// defined elsewhere in the package — not visible here).
	sqlEngine *engine.SQLEngine
	// sessions holds the live sessions keyed by connection ID.
	sessions map[uint32]*PostgreSQLSession
	// sessionMux guards sessions and nextConnID.
	sessionMux sync.RWMutex
	// shutdown is closed by Stop to tell background goroutines to exit.
	shutdown chan struct{}
	// wg counts the accept loop, the idle-cleanup loop and every
	// per-connection goroutine; Stop waits on it.
	wg sync.WaitGroup
	// nextConnID is the next ID returned by generateConnectionID (starts at 1).
	nextConnID uint32
}
|
|
|
|
// PostgreSQLSession is the per-connection state for one client.
type PostgreSQLSession struct {
	// conn is the client connection; close() closes it and sets it to nil.
	conn net.Conn
	// reader and writer buffer conn; writes must be flushed to reach the client.
	reader *bufio.Reader
	writer *bufio.Writer
	// authenticated becomes true once handleStartup completes successfully.
	authenticated bool
	// username is the "user" start-up parameter.
	username string
	// database is the "database" start-up parameter, else config.Database.
	database string
	// parameters holds every name/value pair from the start-up message.
	parameters map[string]string
	// preparedStmts and portals are presumably keyed by statement/portal
	// name for the extended query protocol — TODO confirm against the handlers.
	preparedStmts map[string]*PreparedStatement
	portals       map[string]*Portal
	// transactionState starts as PG_TRANS_IDLE.
	transactionState byte
	// processID is the connection ID; secretKey comes from generateSecretKey.
	processID uint32
	secretKey uint32
	// created is when the session was set up.
	created time.Time
	// lastActivity is refreshed after every handled message and read by
	// cleanupIdleSessions.
	// NOTE(review): those two accesses run on different goroutines without a
	// common lock, which is a data race.
	lastActivity time.Time
	// mutex guards conn inside close().
	mutex sync.Mutex
}
|
|
|
|
// PreparedStatement describes a named prepared statement stored in
// PostgreSQLSession.preparedStmts. It is populated by code outside this file.
type PreparedStatement struct {
	// Name is the statement name (presumably "" for the unnamed statement — TODO confirm).
	Name string
	// Query is the statement's SQL text.
	Query string
	// ParamTypes presumably holds the parameter type OIDs (see PG_TYPE_*) — TODO confirm.
	ParamTypes []uint32
	// Fields describes the result columns.
	Fields []FieldDescription
}
|
|
|
|
// Portal (cursor) is an entry in PostgreSQLSession.portals. It is populated by
// code outside this file.
type Portal struct {
	// Name is the portal name.
	Name string
	// Statement presumably names the prepared statement the portal was bound
	// from — TODO confirm against the Bind handler.
	Statement string
	// Parameters holds the raw parameter values.
	Parameters [][]byte
	// Suspended presumably marks a portal whose execution stopped at a row
	// limit — TODO confirm.
	Suspended bool
}
|
|
|
|
// Field description of one result column.
// NOTE(review): the field set appears to mirror the per-column entries of a
// RowDescription ('T') message — confirm against the encoder elsewhere.
type FieldDescription struct {
	// Name is the column name.
	Name string
	// TableOID and AttrNum presumably identify the source table column (0 if none).
	TableOID uint32
	AttrNum  int16
	// TypeOID is the column's type OID (see the PG_TYPE_* constants).
	TypeOID uint32
	// TypeSize and TypeMod are the type length and modifier.
	TypeSize int16
	TypeMod  int32
	// Format presumably is the format code (text or binary) — TODO confirm.
	Format int16
}
|
|
|
|
// NewPostgreSQLServer creates a new PostgreSQL protocol server
|
|
func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
|
|
if config.Port <= 0 {
|
|
config.Port = DEFAULT_POSTGRES_PORT
|
|
}
|
|
if config.Host == "" {
|
|
config.Host = "localhost"
|
|
}
|
|
if config.Database == "" {
|
|
config.Database = "default"
|
|
}
|
|
if config.MaxConns <= 0 {
|
|
config.MaxConns = 100
|
|
}
|
|
if config.IdleTimeout <= 0 {
|
|
config.IdleTimeout = time.Hour
|
|
}
|
|
if config.StartupTimeout <= 0 {
|
|
config.StartupTimeout = 30 * time.Second
|
|
}
|
|
|
|
// Create SQL engine (now uses CockroachDB parser for PostgreSQL compatibility)
|
|
sqlEngine := engine.NewSQLEngine(masterAddr)
|
|
|
|
server := &PostgreSQLServer{
|
|
config: config,
|
|
sqlEngine: sqlEngine,
|
|
sessions: make(map[uint32]*PostgreSQLSession),
|
|
shutdown: make(chan struct{}),
|
|
nextConnID: 1,
|
|
}
|
|
|
|
return server, nil
|
|
}
|
|
|
|
// Start begins listening for PostgreSQL connections
|
|
func (s *PostgreSQLServer) Start() error {
|
|
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
|
|
|
|
var listener net.Listener
|
|
var err error
|
|
|
|
if s.config.TLSConfig != nil {
|
|
listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
|
|
glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
|
|
} else {
|
|
listener, err = net.Listen("tcp", addr)
|
|
glog.Infof("PostgreSQL Server listening on %s", addr)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
|
|
}
|
|
|
|
s.listener = listener
|
|
|
|
// Start accepting connections
|
|
s.wg.Add(1)
|
|
go s.acceptConnections()
|
|
|
|
// Start cleanup routine
|
|
s.wg.Add(1)
|
|
go s.cleanupSessions()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down the PostgreSQL server
|
|
func (s *PostgreSQLServer) Stop() error {
|
|
close(s.shutdown)
|
|
|
|
if s.listener != nil {
|
|
s.listener.Close()
|
|
}
|
|
|
|
// Close all sessions
|
|
s.sessionMux.Lock()
|
|
for _, session := range s.sessions {
|
|
session.close()
|
|
}
|
|
s.sessions = make(map[uint32]*PostgreSQLSession)
|
|
s.sessionMux.Unlock()
|
|
|
|
s.wg.Wait()
|
|
glog.Infof("PostgreSQL Server stopped")
|
|
return nil
|
|
}
|
|
|
|
// acceptConnections handles incoming PostgreSQL connections
|
|
func (s *PostgreSQLServer) acceptConnections() {
|
|
defer s.wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
}
|
|
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Check connection limit
|
|
s.sessionMux.RLock()
|
|
sessionCount := len(s.sessions)
|
|
s.sessionMux.RUnlock()
|
|
|
|
if sessionCount >= s.config.MaxConns {
|
|
glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
|
|
s.config.MaxConns, conn.RemoteAddr())
|
|
conn.Close()
|
|
continue
|
|
}
|
|
|
|
s.wg.Add(1)
|
|
go s.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
// handleConnection processes a single PostgreSQL connection
|
|
func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
|
|
defer s.wg.Done()
|
|
defer conn.Close()
|
|
|
|
// Generate unique connection ID
|
|
connID := s.generateConnectionID()
|
|
secretKey := s.generateSecretKey()
|
|
|
|
// Create session
|
|
session := &PostgreSQLSession{
|
|
conn: conn,
|
|
reader: bufio.NewReader(conn),
|
|
writer: bufio.NewWriter(conn),
|
|
authenticated: false,
|
|
database: s.config.Database,
|
|
parameters: make(map[string]string),
|
|
preparedStmts: make(map[string]*PreparedStatement),
|
|
portals: make(map[string]*Portal),
|
|
transactionState: PG_TRANS_IDLE,
|
|
processID: connID,
|
|
secretKey: secretKey,
|
|
created: time.Now(),
|
|
lastActivity: time.Now(),
|
|
}
|
|
|
|
// Register session
|
|
s.sessionMux.Lock()
|
|
s.sessions[connID] = session
|
|
s.sessionMux.Unlock()
|
|
|
|
// Clean up on exit
|
|
defer func() {
|
|
s.sessionMux.Lock()
|
|
delete(s.sessions, connID)
|
|
s.sessionMux.Unlock()
|
|
}()
|
|
|
|
glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
|
|
|
|
// Handle startup
|
|
err := s.handleStartup(session)
|
|
if err != nil {
|
|
// Handle common disconnection scenarios more gracefully
|
|
if strings.Contains(err.Error(), "client disconnected") {
|
|
glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
|
|
} else if strings.Contains(err.Error(), "timeout") {
|
|
glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
|
|
} else {
|
|
glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Handle messages
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Set read timeout
|
|
conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
err := s.handleMessage(session)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
|
|
} else {
|
|
glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
session.lastActivity = time.Now()
|
|
}
|
|
}
|
|
|
|
// handleStartup processes the PostgreSQL startup sequence
|
|
func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
|
|
// Set a startup timeout to prevent hanging connections
|
|
startupTimeout := s.config.StartupTimeout
|
|
session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
|
|
defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
|
|
|
|
for {
|
|
// Read startup message length
|
|
length := make([]byte, 4)
|
|
_, err := io.ReadFull(session.reader, length)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
// Client disconnected during startup - this is common for health checks
|
|
return fmt.Errorf("client disconnected during startup handshake")
|
|
}
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
|
|
}
|
|
return fmt.Errorf("failed to read message length during startup: %v", err)
|
|
}
|
|
|
|
msgLength := binary.BigEndian.Uint32(length) - 4
|
|
if msgLength > 10000 { // Reasonable limit for startup messages
|
|
return fmt.Errorf("startup message too large: %d bytes", msgLength)
|
|
}
|
|
|
|
// Read startup message content
|
|
msg := make([]byte, msgLength)
|
|
_, err = io.ReadFull(session.reader, msg)
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return fmt.Errorf("client disconnected while reading startup message")
|
|
}
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
return fmt.Errorf("startup message read timeout")
|
|
}
|
|
return fmt.Errorf("failed to read startup message: %v", err)
|
|
}
|
|
|
|
// Parse protocol version
|
|
protocolVersion := binary.BigEndian.Uint32(msg[0:4])
|
|
|
|
switch protocolVersion {
|
|
case PG_SSL_REQUEST:
|
|
// Reject SSL request - send 'N' to indicate SSL not supported
|
|
_, err = session.conn.Write([]byte{'N'})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to reject SSL request: %v", err)
|
|
}
|
|
// Continue loop to read the actual startup message
|
|
continue
|
|
|
|
case PG_GSSAPI_REQUEST:
|
|
// Reject GSSAPI request - send 'N' to indicate GSSAPI not supported
|
|
_, err = session.conn.Write([]byte{'N'})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to reject GSSAPI request: %v", err)
|
|
}
|
|
// Continue loop to read the actual startup message
|
|
continue
|
|
|
|
case PG_PROTOCOL_VERSION_3:
|
|
// This is the actual startup message, break out of loop
|
|
break
|
|
|
|
default:
|
|
return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
|
|
}
|
|
|
|
// Parse parameters
|
|
params := strings.Split(string(msg[4:]), "\x00")
|
|
for i := 0; i < len(params)-1; i += 2 {
|
|
if params[i] == "user" {
|
|
session.username = params[i+1]
|
|
} else if params[i] == "database" {
|
|
session.database = params[i+1]
|
|
}
|
|
session.parameters[params[i]] = params[i+1]
|
|
}
|
|
|
|
// Break out of the main loop - we have the startup message
|
|
break
|
|
}
|
|
|
|
// Handle authentication
|
|
err := s.handleAuthentication(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send parameter status messages
|
|
err = s.sendParameterStatus(session, "server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.sendParameterStatus(session, "server_encoding", "UTF8")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.sendParameterStatus(session, "client_encoding", "UTF8")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = s.sendParameterStatus(session, "integer_datetimes", "on")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send backend key data
|
|
err = s.sendBackendKeyData(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send ready for query
|
|
err = s.sendReadyForQuery(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
session.authenticated = true
|
|
return nil
|
|
}
|
|
|
|
// handleAuthentication processes authentication
|
|
func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
|
|
switch s.config.AuthMethod {
|
|
case AuthTrust:
|
|
return s.sendAuthenticationOk(session)
|
|
case AuthPassword:
|
|
return s.handlePasswordAuth(session)
|
|
case AuthMD5:
|
|
return s.handleMD5Auth(session)
|
|
default:
|
|
return fmt.Errorf("unsupported authentication method")
|
|
}
|
|
}
|
|
|
|
// sendAuthenticationOk sends authentication OK message
|
|
func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
|
|
msg := make([]byte, 9)
|
|
msg[0] = PG_RESP_AUTH_OK
|
|
binary.BigEndian.PutUint32(msg[1:5], 8)
|
|
binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
|
|
|
|
_, err := session.writer.Write(msg)
|
|
if err == nil {
|
|
err = session.writer.Flush()
|
|
}
|
|
return err
|
|
}
|
|
|
|
// handlePasswordAuth handles clear password authentication
|
|
func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
|
|
// Send password request
|
|
msg := make([]byte, 9)
|
|
msg[0] = PG_RESP_AUTH_OK
|
|
binary.BigEndian.PutUint32(msg[1:5], 8)
|
|
binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
|
|
|
|
_, err := session.writer.Write(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = session.writer.Flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Read password response
|
|
msgType := make([]byte, 1)
|
|
_, err = io.ReadFull(session.reader, msgType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if msgType[0] != PG_MSG_PASSWORD {
|
|
return fmt.Errorf("expected password message, got %c", msgType[0])
|
|
}
|
|
|
|
length := make([]byte, 4)
|
|
_, err = io.ReadFull(session.reader, length)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
msgLength := binary.BigEndian.Uint32(length) - 4
|
|
password := make([]byte, msgLength)
|
|
_, err = io.ReadFull(session.reader, password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify password
|
|
expectedPassword, exists := s.config.Users[session.username]
|
|
if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
|
|
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
|
}
|
|
|
|
return s.sendAuthenticationOk(session)
|
|
}
|
|
|
|
// handleMD5Auth handles MD5 password authentication
|
|
func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
|
|
// Generate salt
|
|
salt := make([]byte, 4)
|
|
_, err := rand.Read(salt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send MD5 request
|
|
msg := make([]byte, 13)
|
|
msg[0] = PG_RESP_AUTH_OK
|
|
binary.BigEndian.PutUint32(msg[1:5], 12)
|
|
binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
|
|
copy(msg[9:13], salt)
|
|
|
|
_, err = session.writer.Write(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = session.writer.Flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Read password response
|
|
msgType := make([]byte, 1)
|
|
_, err = io.ReadFull(session.reader, msgType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if msgType[0] != PG_MSG_PASSWORD {
|
|
return fmt.Errorf("expected password message, got %c", msgType[0])
|
|
}
|
|
|
|
length := make([]byte, 4)
|
|
_, err = io.ReadFull(session.reader, length)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
msgLength := binary.BigEndian.Uint32(length) - 4
|
|
response := make([]byte, msgLength)
|
|
_, err = io.ReadFull(session.reader, response)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify MD5 hash
|
|
expectedPassword, exists := s.config.Users[session.username]
|
|
if !exists {
|
|
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
|
}
|
|
|
|
// Calculate expected hash: md5(md5(password + username) + salt)
|
|
inner := md5.Sum([]byte(expectedPassword + session.username))
|
|
expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
|
|
|
|
if string(response[:len(response)-1]) != expected { // Remove null terminator
|
|
return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
|
|
}
|
|
|
|
return s.sendAuthenticationOk(session)
|
|
}
|
|
|
|
// generateConnectionID generates a unique connection ID
|
|
func (s *PostgreSQLServer) generateConnectionID() uint32 {
|
|
s.sessionMux.Lock()
|
|
defer s.sessionMux.Unlock()
|
|
id := s.nextConnID
|
|
s.nextConnID++
|
|
return id
|
|
}
|
|
|
|
// generateSecretKey generates a secret key for the connection
|
|
func (s *PostgreSQLServer) generateSecretKey() uint32 {
|
|
key := make([]byte, 4)
|
|
rand.Read(key)
|
|
return binary.BigEndian.Uint32(key)
|
|
}
|
|
|
|
// close marks the session as closed
|
|
func (s *PostgreSQLSession) close() {
|
|
s.mutex.Lock()
|
|
defer s.mutex.Unlock()
|
|
if s.conn != nil {
|
|
s.conn.Close()
|
|
s.conn = nil
|
|
}
|
|
}
|
|
|
|
// cleanupSessions periodically cleans up idle sessions
|
|
func (s *PostgreSQLServer) cleanupSessions() {
|
|
defer s.wg.Done()
|
|
|
|
ticker := time.NewTicker(time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
case <-ticker.C:
|
|
s.cleanupIdleSessions()
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanupIdleSessions removes sessions that have been idle too long
|
|
func (s *PostgreSQLServer) cleanupIdleSessions() {
|
|
now := time.Now()
|
|
|
|
s.sessionMux.Lock()
|
|
defer s.sessionMux.Unlock()
|
|
|
|
for id, session := range s.sessions {
|
|
if now.Sub(session.lastActivity) > s.config.IdleTimeout {
|
|
glog.Infof("Closing idle PostgreSQL session %d", id)
|
|
session.close()
|
|
delete(s.sessions, id)
|
|
}
|
|
}
|
|
}
|