chore: remove ~50k lines of unreachable dead code (#8913)

* chore: remove unreachable dead code across the codebase

Remove ~50,000 lines of unreachable code identified by static analysis.

Major removals:
- weed/filer/redis_lua: entire unused Redis Lua filer store implementation
- weed/wdclient/net2, resource_pool: unused connection/resource pool packages
- weed/plugin/worker/lifecycle: unused lifecycle plugin worker
- weed/s3api: unused S3 policy templates, presigned URL IAM, streaming copy,
  multipart IAM, key rotation, and various SSE helper functions
- weed/mq/kafka: unused partition mapping, compression, schema, and protocol functions
- weed/mq/offset: unused SQL storage and migration code
- weed/worker: unused registry, task, and monitoring functions
- weed/query: unused SQL engine, parquet scanner, and type functions
- weed/shell: unused EC proportional rebalance functions
- weed/storage/erasure_coding/distribution: unused distribution analysis functions
- Individual unreachable functions removed from 150+ files across admin,
  credential, filer, iam, kms, mount, mq, operation, pb, s3api, server,
  shell, storage, topology, and util packages

* fix(s3): reset shared memory store in IAM test to prevent flaky failure

TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey was flaky because
the MemoryStore credential backend is a singleton registered via init().
Earlier tests that create anonymous identities pollute the shared store,
causing LookupAnonymous() to unexpectedly return true.

Fix by calling Reset() on the memory store before the test runs.

* style: run gofmt on changed files

* fix: restore KMS functions used by integration tests

* fix(plugin): prevent panic on send to closed worker session channel

The Plugin.sendToWorker method could panic with "send on closed channel"
when a worker disconnected while a message was being sent. The race was
between streamSession.close() closing the outgoing channel and sendToWorker
writing to it concurrently.

Add a done channel to streamSession that is closed before the outgoing
channel, and check it in sendToWorker's select to safely detect closed
sessions without panicking.
This commit is contained in:
Chris Lu
2026-04-03 16:04:27 -07:00
committed by GitHub
parent 8fad85aed7
commit 995dfc4d5d
264 changed files with 62 additions and 46027 deletions

View File

@@ -0,0 +1 @@
{"sessionId":"d6574c47-eafc-4a94-9dce-f9ffea22b53c","pid":10111,"acquiredAt":1775248373916}

5
.superset/config.json Normal file
View File

@@ -0,0 +1,5 @@
{
"setup": [],
"teardown": [],
"run": []
}

View File

@@ -2561,6 +2561,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "openssl-src"
version = "300.5.5+3.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "openssl-sys" name = "openssl-sys"
version = "0.9.111" version = "0.9.111"
@@ -2569,6 +2578,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [ dependencies = [
"cc", "cc",
"libc", "libc",
"openssl-src",
"pkg-config", "pkg-config",
"vcpkg", "vcpkg",
] ]
@@ -4654,6 +4664,7 @@ dependencies = [
"memmap2", "memmap2",
"mime_guess", "mime_guess",
"multer", "multer",
"openssl",
"parking_lot 0.12.5", "parking_lot 0.12.5",
"pprof", "pprof",
"prometheus", "prometheus",

View File

@@ -447,11 +447,6 @@ type QueueStats = maintenance.QueueStats
type WorkerDetailsData = maintenance.WorkerDetailsData type WorkerDetailsData = maintenance.WorkerDetailsData
type WorkerPerformance = maintenance.WorkerPerformance type WorkerPerformance = maintenance.WorkerPerformance
// GetTaskIcon returns the icon CSS class for a task type from its UI provider
func GetTaskIcon(taskType MaintenanceTaskType) string {
return maintenance.GetTaskIcon(taskType)
}
// Status constants (these are still static) // Status constants (these are still static)
const ( const (
TaskStatusPending = maintenance.TaskStatusPending TaskStatusPending = maintenance.TaskStatusPending

View File

@@ -312,29 +312,6 @@ func (h *ClusterHandlers) ShowClusterFilers(w http.ResponseWriter, r *http.Reque
} }
} }
// ShowClusterBrokers renders the cluster message brokers page
func (h *ClusterHandlers) ShowClusterBrokers(w http.ResponseWriter, r *http.Request) {
// Get cluster brokers data
brokersData, err := h.adminServer.GetClusterBrokers()
if err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to get cluster brokers: "+err.Error())
return
}
username := usernameOrDefault(r)
brokersData.Username = username
// Render HTML template
w.Header().Set("Content-Type", "text/html")
brokersComponent := app.ClusterBrokers(*brokersData)
viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
layoutComponent := layout.Layout(viewCtx, brokersComponent)
if err := layoutComponent.Render(r.Context(), w); err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
return
}
}
// GetClusterTopology returns the cluster topology as JSON // GetClusterTopology returns the cluster topology as JSON
func (h *ClusterHandlers) GetClusterTopology(w http.ResponseWriter, r *http.Request) { func (h *ClusterHandlers) GetClusterTopology(w http.ResponseWriter, r *http.Request) {
topology, err := h.adminServer.GetClusterTopology() topology, err := h.adminServer.GetClusterTopology()

View File

@@ -78,34 +78,6 @@ func (h *MessageQueueHandlers) ShowTopics(w http.ResponseWriter, r *http.Request
} }
} }
// ShowSubscribers renders the message queue subscribers page
func (h *MessageQueueHandlers) ShowSubscribers(w http.ResponseWriter, r *http.Request) {
// Get subscribers data
subscribersData, err := h.adminServer.GetSubscribers()
if err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to get subscribers: "+err.Error())
return
}
// Set username
username := dash.UsernameFromContext(r.Context())
if username == "" {
username = "admin"
}
subscribersData.Username = username
// Render HTML template
w.Header().Set("Content-Type", "text/html")
subscribersComponent := app.Subscribers(*subscribersData)
viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
layoutComponent := layout.Layout(viewCtx, subscribersComponent)
err = layoutComponent.Render(r.Context(), w)
if err != nil {
writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
return
}
}
// ShowTopicDetails renders the topic details page // ShowTopicDetails renders the topic details page
func (h *MessageQueueHandlers) ShowTopicDetails(w http.ResponseWriter, r *http.Request) { func (h *MessageQueueHandlers) ShowTopicDetails(w http.ResponseWriter, r *http.Request) {
// Get topic parameters from URL // Get topic parameters from URL

View File

@@ -1,124 +0,0 @@
package maintenance
import (
"fmt"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
)
// VerifyProtobufConfig demonstrates that the protobuf configuration system is working
func VerifyProtobufConfig() error {
// Create configuration manager
configManager := NewMaintenanceConfigManager()
config := configManager.GetConfig()
// Verify basic configuration
if !config.Enabled {
return fmt.Errorf("expected config to be enabled by default")
}
if config.ScanIntervalSeconds != 30*60 {
return fmt.Errorf("expected scan interval to be 1800 seconds, got %d", config.ScanIntervalSeconds)
}
// Verify policy configuration
if config.Policy == nil {
return fmt.Errorf("expected policy to be configured")
}
if config.Policy.GlobalMaxConcurrent != 4 {
return fmt.Errorf("expected global max concurrent to be 4, got %d", config.Policy.GlobalMaxConcurrent)
}
// Verify task policies
vacuumPolicy := config.Policy.TaskPolicies["vacuum"]
if vacuumPolicy == nil {
return fmt.Errorf("expected vacuum policy to be configured")
}
if !vacuumPolicy.Enabled {
return fmt.Errorf("expected vacuum policy to be enabled")
}
// Verify typed configuration access
vacuumConfig := vacuumPolicy.GetVacuumConfig()
if vacuumConfig == nil {
return fmt.Errorf("expected vacuum config to be accessible")
}
if vacuumConfig.GarbageThreshold != 0.3 {
return fmt.Errorf("expected garbage threshold to be 0.3, got %f", vacuumConfig.GarbageThreshold)
}
// Verify helper functions work
if !IsTaskEnabled(config.Policy, "vacuum") {
return fmt.Errorf("expected vacuum task to be enabled via helper function")
}
maxConcurrent := GetMaxConcurrent(config.Policy, "vacuum")
if maxConcurrent != 2 {
return fmt.Errorf("expected vacuum max concurrent to be 2, got %d", maxConcurrent)
}
// Verify erasure coding configuration
ecPolicy := config.Policy.TaskPolicies["erasure_coding"]
if ecPolicy == nil {
return fmt.Errorf("expected EC policy to be configured")
}
ecConfig := ecPolicy.GetErasureCodingConfig()
if ecConfig == nil {
return fmt.Errorf("expected EC config to be accessible")
}
// Verify configurable EC fields only
if ecConfig.FullnessRatio <= 0 || ecConfig.FullnessRatio > 1 {
return fmt.Errorf("expected EC config to have valid fullness ratio (0-1), got %f", ecConfig.FullnessRatio)
}
return nil
}
// GetProtobufConfigSummary returns a summary of the current protobuf configuration
func GetProtobufConfigSummary() string {
configManager := NewMaintenanceConfigManager()
config := configManager.GetConfig()
summary := fmt.Sprintf("SeaweedFS Protobuf Maintenance Configuration:\n")
summary += fmt.Sprintf(" Enabled: %v\n", config.Enabled)
summary += fmt.Sprintf(" Scan Interval: %d seconds\n", config.ScanIntervalSeconds)
summary += fmt.Sprintf(" Max Retries: %d\n", config.MaxRetries)
summary += fmt.Sprintf(" Global Max Concurrent: %d\n", config.Policy.GlobalMaxConcurrent)
summary += fmt.Sprintf(" Task Policies: %d configured\n", len(config.Policy.TaskPolicies))
for taskType, policy := range config.Policy.TaskPolicies {
summary += fmt.Sprintf(" %s: enabled=%v, max_concurrent=%d\n",
taskType, policy.Enabled, policy.MaxConcurrent)
}
return summary
}
// CreateCustomConfig demonstrates creating a custom protobuf configuration
func CreateCustomConfig() *worker_pb.MaintenanceConfig {
return &worker_pb.MaintenanceConfig{
Enabled: true,
ScanIntervalSeconds: 60 * 60, // 1 hour
MaxRetries: 5,
Policy: &worker_pb.MaintenancePolicy{
GlobalMaxConcurrent: 8,
TaskPolicies: map[string]*worker_pb.TaskPolicy{
"custom_vacuum": {
Enabled: true,
MaxConcurrent: 4,
TaskConfig: &worker_pb.TaskPolicy_VacuumConfig{
VacuumConfig: &worker_pb.VacuumTaskConfig{
GarbageThreshold: 0.5,
MinVolumeAgeHours: 48,
},
},
},
},
},
}
}

View File

@@ -1,24 +1,9 @@
package maintenance package maintenance
import ( import (
"fmt"
"time"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
) )
// MaintenanceConfigManager handles protobuf-based configuration
type MaintenanceConfigManager struct {
config *worker_pb.MaintenanceConfig
}
// NewMaintenanceConfigManager creates a new config manager with defaults
func NewMaintenanceConfigManager() *MaintenanceConfigManager {
return &MaintenanceConfigManager{
config: DefaultMaintenanceConfigProto(),
}
}
// DefaultMaintenanceConfigProto returns default configuration as protobuf // DefaultMaintenanceConfigProto returns default configuration as protobuf
func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig { func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
return &worker_pb.MaintenanceConfig{ return &worker_pb.MaintenanceConfig{
@@ -34,253 +19,3 @@ func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
Policy: nil, Policy: nil,
} }
} }
// GetConfig returns the current configuration
func (mcm *MaintenanceConfigManager) GetConfig() *worker_pb.MaintenanceConfig {
return mcm.config
}
// Type-safe configuration accessors
// GetVacuumConfig returns vacuum-specific configuration for a task type
func (mcm *MaintenanceConfigManager) GetVacuumConfig(taskType string) *worker_pb.VacuumTaskConfig {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
if vacuumConfig := policy.GetVacuumConfig(); vacuumConfig != nil {
return vacuumConfig
}
}
// Return defaults if not configured
return &worker_pb.VacuumTaskConfig{
GarbageThreshold: 0.3,
MinVolumeAgeHours: 24,
}
}
// GetErasureCodingConfig returns EC-specific configuration for a task type
func (mcm *MaintenanceConfigManager) GetErasureCodingConfig(taskType string) *worker_pb.ErasureCodingTaskConfig {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
if ecConfig := policy.GetErasureCodingConfig(); ecConfig != nil {
return ecConfig
}
}
// Return defaults if not configured
return &worker_pb.ErasureCodingTaskConfig{
FullnessRatio: 0.95,
QuietForSeconds: 3600,
MinVolumeSizeMb: 100,
CollectionFilter: "",
}
}
// GetBalanceConfig returns balance-specific configuration for a task type
func (mcm *MaintenanceConfigManager) GetBalanceConfig(taskType string) *worker_pb.BalanceTaskConfig {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
if balanceConfig := policy.GetBalanceConfig(); balanceConfig != nil {
return balanceConfig
}
}
// Return defaults if not configured
return &worker_pb.BalanceTaskConfig{
ImbalanceThreshold: 0.2,
MinServerCount: 2,
}
}
// GetReplicationConfig returns replication-specific configuration for a task type
func (mcm *MaintenanceConfigManager) GetReplicationConfig(taskType string) *worker_pb.ReplicationTaskConfig {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
if replicationConfig := policy.GetReplicationConfig(); replicationConfig != nil {
return replicationConfig
}
}
// Return defaults if not configured
return &worker_pb.ReplicationTaskConfig{
TargetReplicaCount: 2,
}
}
// Typed convenience methods for getting task configurations
// GetVacuumTaskConfigForType returns vacuum configuration for a specific task type
func (mcm *MaintenanceConfigManager) GetVacuumTaskConfigForType(taskType string) *worker_pb.VacuumTaskConfig {
return GetVacuumTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
}
// GetErasureCodingTaskConfigForType returns erasure coding configuration for a specific task type
func (mcm *MaintenanceConfigManager) GetErasureCodingTaskConfigForType(taskType string) *worker_pb.ErasureCodingTaskConfig {
return GetErasureCodingTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
}
// GetBalanceTaskConfigForType returns balance configuration for a specific task type
func (mcm *MaintenanceConfigManager) GetBalanceTaskConfigForType(taskType string) *worker_pb.BalanceTaskConfig {
return GetBalanceTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
}
// GetReplicationTaskConfigForType returns replication configuration for a specific task type
func (mcm *MaintenanceConfigManager) GetReplicationTaskConfigForType(taskType string) *worker_pb.ReplicationTaskConfig {
return GetReplicationTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
}
// Helper methods
func (mcm *MaintenanceConfigManager) getTaskPolicy(taskType string) *worker_pb.TaskPolicy {
if mcm.config.Policy != nil && mcm.config.Policy.TaskPolicies != nil {
return mcm.config.Policy.TaskPolicies[taskType]
}
return nil
}
// IsTaskEnabled returns whether a task type is enabled
func (mcm *MaintenanceConfigManager) IsTaskEnabled(taskType string) bool {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
return policy.Enabled
}
return false
}
// GetMaxConcurrent returns the max concurrent limit for a task type
func (mcm *MaintenanceConfigManager) GetMaxConcurrent(taskType string) int32 {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
return policy.MaxConcurrent
}
return 1 // Default
}
// GetRepeatInterval returns the repeat interval for a task type in seconds
func (mcm *MaintenanceConfigManager) GetRepeatInterval(taskType string) int32 {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
return policy.RepeatIntervalSeconds
}
return mcm.config.Policy.DefaultRepeatIntervalSeconds
}
// GetCheckInterval returns the check interval for a task type in seconds
func (mcm *MaintenanceConfigManager) GetCheckInterval(taskType string) int32 {
if policy := mcm.getTaskPolicy(taskType); policy != nil {
return policy.CheckIntervalSeconds
}
return mcm.config.Policy.DefaultCheckIntervalSeconds
}
// Duration accessor methods
// GetScanInterval returns the scan interval as a time.Duration
func (mcm *MaintenanceConfigManager) GetScanInterval() time.Duration {
return time.Duration(mcm.config.ScanIntervalSeconds) * time.Second
}
// GetWorkerTimeout returns the worker timeout as a time.Duration
func (mcm *MaintenanceConfigManager) GetWorkerTimeout() time.Duration {
return time.Duration(mcm.config.WorkerTimeoutSeconds) * time.Second
}
// GetTaskTimeout returns the task timeout as a time.Duration
func (mcm *MaintenanceConfigManager) GetTaskTimeout() time.Duration {
return time.Duration(mcm.config.TaskTimeoutSeconds) * time.Second
}
// GetRetryDelay returns the retry delay as a time.Duration
func (mcm *MaintenanceConfigManager) GetRetryDelay() time.Duration {
return time.Duration(mcm.config.RetryDelaySeconds) * time.Second
}
// GetCleanupInterval returns the cleanup interval as a time.Duration
func (mcm *MaintenanceConfigManager) GetCleanupInterval() time.Duration {
return time.Duration(mcm.config.CleanupIntervalSeconds) * time.Second
}
// GetTaskRetention returns the task retention period as a time.Duration
func (mcm *MaintenanceConfigManager) GetTaskRetention() time.Duration {
return time.Duration(mcm.config.TaskRetentionSeconds) * time.Second
}
// ValidateMaintenanceConfigWithSchema validates protobuf maintenance configuration using ConfigField rules
func ValidateMaintenanceConfigWithSchema(config *worker_pb.MaintenanceConfig) error {
if config == nil {
return fmt.Errorf("configuration cannot be nil")
}
// Get the schema to access field validation rules
schema := GetMaintenanceConfigSchema()
// Validate each field individually using the ConfigField rules
if err := validateFieldWithSchema(schema, "enabled", config.Enabled); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "scan_interval_seconds", int(config.ScanIntervalSeconds)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "worker_timeout_seconds", int(config.WorkerTimeoutSeconds)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "task_timeout_seconds", int(config.TaskTimeoutSeconds)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "retry_delay_seconds", int(config.RetryDelaySeconds)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "max_retries", int(config.MaxRetries)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "cleanup_interval_seconds", int(config.CleanupIntervalSeconds)); err != nil {
return err
}
if err := validateFieldWithSchema(schema, "task_retention_seconds", int(config.TaskRetentionSeconds)); err != nil {
return err
}
// Validate policy fields if present
if config.Policy != nil {
// Note: These field names might need to be adjusted based on the actual schema
if err := validatePolicyField("global_max_concurrent", int(config.Policy.GlobalMaxConcurrent)); err != nil {
return err
}
if err := validatePolicyField("default_repeat_interval_seconds", int(config.Policy.DefaultRepeatIntervalSeconds)); err != nil {
return err
}
if err := validatePolicyField("default_check_interval_seconds", int(config.Policy.DefaultCheckIntervalSeconds)); err != nil {
return err
}
}
return nil
}
// validateFieldWithSchema validates a single field using its ConfigField definition
func validateFieldWithSchema(schema *MaintenanceConfigSchema, fieldName string, value interface{}) error {
field := schema.GetFieldByName(fieldName)
if field == nil {
// Field not in schema, skip validation
return nil
}
return field.ValidateValue(value)
}
// validatePolicyField validates policy fields (simplified validation for now)
func validatePolicyField(fieldName string, value int) error {
switch fieldName {
case "global_max_concurrent":
if value < 1 || value > 20 {
return fmt.Errorf("Global Max Concurrent must be between 1 and 20, got %d", value)
}
case "default_repeat_interval":
if value < 1 || value > 168 {
return fmt.Errorf("Default Repeat Interval must be between 1 and 168 hours, got %d", value)
}
case "default_check_interval":
if value < 1 || value > 168 {
return fmt.Errorf("Default Check Interval must be between 1 and 168 hours, got %d", value)
}
}
return nil
}

View File

@@ -1055,28 +1055,6 @@ func (mq *MaintenanceQueue) getMaxConcurrentForTaskType(taskType MaintenanceTask
return 1 return 1
} }
// getRunningTasks returns all currently running tasks
func (mq *MaintenanceQueue) getRunningTasks() []*MaintenanceTask {
var runningTasks []*MaintenanceTask
for _, task := range mq.tasks {
if task.Status == TaskStatusAssigned || task.Status == TaskStatusInProgress {
runningTasks = append(runningTasks, task)
}
}
return runningTasks
}
// getAvailableWorkers returns all workers that can take more work
func (mq *MaintenanceQueue) getAvailableWorkers() []*MaintenanceWorker {
var availableWorkers []*MaintenanceWorker
for _, worker := range mq.workers {
if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent {
availableWorkers = append(availableWorkers, worker)
}
}
return availableWorkers
}
// trackPendingOperation adds a task to the pending operations tracker // trackPendingOperation adds a task to the pending operations tracker
func (mq *MaintenanceQueue) trackPendingOperation(task *MaintenanceTask) { func (mq *MaintenanceQueue) trackPendingOperation(task *MaintenanceTask) {
if mq.integration == nil { if mq.integration == nil {

View File

@@ -2,15 +2,11 @@ package maintenance
import ( import (
"html/template" "html/template"
"sort"
"sync" "sync"
"time" "time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb" "github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
"github.com/seaweedfs/seaweedfs/weed/worker/tasks"
"github.com/seaweedfs/seaweedfs/weed/worker/types"
) )
// AdminClient interface defines what the maintenance system needs from the admin server // AdminClient interface defines what the maintenance system needs from the admin server
@@ -21,51 +17,6 @@ type AdminClient interface {
// MaintenanceTaskType represents different types of maintenance operations // MaintenanceTaskType represents different types of maintenance operations
type MaintenanceTaskType string type MaintenanceTaskType string
// GetRegisteredMaintenanceTaskTypes returns all registered task types as MaintenanceTaskType values
// sorted alphabetically for consistent menu ordering
func GetRegisteredMaintenanceTaskTypes() []MaintenanceTaskType {
typesRegistry := tasks.GetGlobalTypesRegistry()
var taskTypes []MaintenanceTaskType
for workerTaskType := range typesRegistry.GetAllDetectors() {
maintenanceTaskType := MaintenanceTaskType(string(workerTaskType))
taskTypes = append(taskTypes, maintenanceTaskType)
}
// Sort task types alphabetically to ensure consistent menu ordering
sort.Slice(taskTypes, func(i, j int) bool {
return string(taskTypes[i]) < string(taskTypes[j])
})
return taskTypes
}
// GetMaintenanceTaskType returns a specific task type if it's registered, or empty string if not found
func GetMaintenanceTaskType(taskTypeName string) MaintenanceTaskType {
typesRegistry := tasks.GetGlobalTypesRegistry()
for workerTaskType := range typesRegistry.GetAllDetectors() {
if string(workerTaskType) == taskTypeName {
return MaintenanceTaskType(taskTypeName)
}
}
return MaintenanceTaskType("")
}
// IsMaintenanceTaskTypeRegistered checks if a task type is registered
func IsMaintenanceTaskTypeRegistered(taskType MaintenanceTaskType) bool {
typesRegistry := tasks.GetGlobalTypesRegistry()
for workerTaskType := range typesRegistry.GetAllDetectors() {
if string(workerTaskType) == string(taskType) {
return true
}
}
return false
}
// MaintenanceTaskPriority represents task execution priority // MaintenanceTaskPriority represents task execution priority
type MaintenanceTaskPriority int type MaintenanceTaskPriority int
@@ -200,14 +151,6 @@ func GetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType) *TaskPol
return mp.TaskPolicies[string(taskType)] return mp.TaskPolicies[string(taskType)]
} }
// SetTaskPolicy sets the policy for a specific task type
func SetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType, policy *TaskPolicy) {
if mp.TaskPolicies == nil {
mp.TaskPolicies = make(map[string]*TaskPolicy)
}
mp.TaskPolicies[string(taskType)] = policy
}
// IsTaskEnabled returns whether a task type is enabled // IsTaskEnabled returns whether a task type is enabled
func IsTaskEnabled(mp *MaintenancePolicy, taskType MaintenanceTaskType) bool { func IsTaskEnabled(mp *MaintenancePolicy, taskType MaintenanceTaskType) bool {
policy := GetTaskPolicy(mp, taskType) policy := GetTaskPolicy(mp, taskType)
@@ -235,84 +178,6 @@ func GetRepeatInterval(mp *MaintenancePolicy, taskType MaintenanceTaskType) int
return int(policy.RepeatIntervalSeconds) return int(policy.RepeatIntervalSeconds)
} }
// GetVacuumTaskConfig returns the vacuum task configuration
func GetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.VacuumTaskConfig {
policy := GetTaskPolicy(mp, taskType)
if policy == nil {
return nil
}
return policy.GetVacuumConfig()
}
// GetErasureCodingTaskConfig returns the erasure coding task configuration
func GetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ErasureCodingTaskConfig {
policy := GetTaskPolicy(mp, taskType)
if policy == nil {
return nil
}
return policy.GetErasureCodingConfig()
}
// GetBalanceTaskConfig returns the balance task configuration
func GetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.BalanceTaskConfig {
policy := GetTaskPolicy(mp, taskType)
if policy == nil {
return nil
}
return policy.GetBalanceConfig()
}
// GetReplicationTaskConfig returns the replication task configuration
func GetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ReplicationTaskConfig {
policy := GetTaskPolicy(mp, taskType)
if policy == nil {
return nil
}
return policy.GetReplicationConfig()
}
// Note: GetTaskConfig was removed - use typed getters: GetVacuumTaskConfig, GetErasureCodingTaskConfig, GetBalanceTaskConfig, or GetReplicationTaskConfig
// SetVacuumTaskConfig sets the vacuum task configuration
func SetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.VacuumTaskConfig) {
policy := GetTaskPolicy(mp, taskType)
if policy != nil {
policy.TaskConfig = &worker_pb.TaskPolicy_VacuumConfig{
VacuumConfig: config,
}
}
}
// SetErasureCodingTaskConfig sets the erasure coding task configuration
func SetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ErasureCodingTaskConfig) {
policy := GetTaskPolicy(mp, taskType)
if policy != nil {
policy.TaskConfig = &worker_pb.TaskPolicy_ErasureCodingConfig{
ErasureCodingConfig: config,
}
}
}
// SetBalanceTaskConfig sets the balance task configuration
func SetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.BalanceTaskConfig) {
policy := GetTaskPolicy(mp, taskType)
if policy != nil {
policy.TaskConfig = &worker_pb.TaskPolicy_BalanceConfig{
BalanceConfig: config,
}
}
}
// SetReplicationTaskConfig sets the replication task configuration
func SetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ReplicationTaskConfig) {
policy := GetTaskPolicy(mp, taskType)
if policy != nil {
policy.TaskConfig = &worker_pb.TaskPolicy_ReplicationConfig{
ReplicationConfig: config,
}
}
}
// SetTaskConfig sets a configuration value for a task type (legacy method - use typed setters above) // SetTaskConfig sets a configuration value for a task type (legacy method - use typed setters above)
// Note: SetTaskConfig was removed - use typed setters: SetVacuumTaskConfig, SetErasureCodingTaskConfig, SetBalanceTaskConfig, or SetReplicationTaskConfig // Note: SetTaskConfig was removed - use typed setters: SetVacuumTaskConfig, SetErasureCodingTaskConfig, SetBalanceTaskConfig, or SetReplicationTaskConfig
@@ -475,180 +340,6 @@ type ClusterReplicationTask struct {
Metadata map[string]string `json:"metadata,omitempty"` Metadata map[string]string `json:"metadata,omitempty"`
} }
// BuildMaintenancePolicyFromTasks creates a maintenance policy with configurations
// from all registered tasks using their UI providers
func BuildMaintenancePolicyFromTasks() *MaintenancePolicy {
policy := &MaintenancePolicy{
TaskPolicies: make(map[string]*TaskPolicy),
GlobalMaxConcurrent: 4,
DefaultRepeatIntervalSeconds: 6 * 3600, // 6 hours in seconds
DefaultCheckIntervalSeconds: 12 * 3600, // 12 hours in seconds
}
// Get all registered task types from the UI registry
uiRegistry := tasks.GetGlobalUIRegistry()
typesRegistry := tasks.GetGlobalTypesRegistry()
for taskType, provider := range uiRegistry.GetAllProviders() {
// Convert task type to maintenance task type
maintenanceTaskType := MaintenanceTaskType(string(taskType))
// Get the default configuration from the UI provider
defaultConfig := provider.GetCurrentConfig()
// Create task policy from UI configuration
taskPolicy := &TaskPolicy{
Enabled: true, // Default enabled
MaxConcurrent: 2, // Default concurrency
RepeatIntervalSeconds: policy.DefaultRepeatIntervalSeconds,
CheckIntervalSeconds: policy.DefaultCheckIntervalSeconds,
}
// Extract configuration using TaskConfig interface - no more map conversions!
if taskConfig, ok := defaultConfig.(interface{ ToTaskPolicy() *worker_pb.TaskPolicy }); ok {
// Use protobuf directly for clean, type-safe config extraction
pbTaskPolicy := taskConfig.ToTaskPolicy()
taskPolicy.Enabled = pbTaskPolicy.Enabled
taskPolicy.MaxConcurrent = pbTaskPolicy.MaxConcurrent
if pbTaskPolicy.RepeatIntervalSeconds > 0 {
taskPolicy.RepeatIntervalSeconds = pbTaskPolicy.RepeatIntervalSeconds
}
if pbTaskPolicy.CheckIntervalSeconds > 0 {
taskPolicy.CheckIntervalSeconds = pbTaskPolicy.CheckIntervalSeconds
}
}
// Also get defaults from scheduler if available (using types.TaskScheduler explicitly)
var scheduler types.TaskScheduler = typesRegistry.GetScheduler(taskType)
if scheduler != nil {
if taskPolicy.MaxConcurrent <= 0 {
taskPolicy.MaxConcurrent = int32(scheduler.GetMaxConcurrent())
}
// Convert default repeat interval to seconds
if repeatInterval := scheduler.GetDefaultRepeatInterval(); repeatInterval > 0 {
taskPolicy.RepeatIntervalSeconds = int32(repeatInterval.Seconds())
}
}
// Also get defaults from detector if available (using types.TaskDetector explicitly)
var detector types.TaskDetector = typesRegistry.GetDetector(taskType)
if detector != nil {
// Convert scan interval to check interval (seconds)
if scanInterval := detector.ScanInterval(); scanInterval > 0 {
taskPolicy.CheckIntervalSeconds = int32(scanInterval.Seconds())
}
}
policy.TaskPolicies[string(maintenanceTaskType)] = taskPolicy
glog.V(3).Infof("Built policy for task type %s: enabled=%v, max_concurrent=%d",
maintenanceTaskType, taskPolicy.Enabled, taskPolicy.MaxConcurrent)
}
glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskPolicies))
return policy
}
// SetPolicyFromTasks sets the maintenance policy from registered tasks
func SetPolicyFromTasks(policy *MaintenancePolicy) {
if policy == nil {
return
}
// Build new policy from tasks
newPolicy := BuildMaintenancePolicyFromTasks()
// Copy task policies
policy.TaskPolicies = newPolicy.TaskPolicies
glog.V(1).Infof("Updated maintenance policy with %d task configurations from registered tasks", len(policy.TaskPolicies))
}
// GetTaskIcon returns the icon CSS class for a task type from its UI
// provider, or a default cog icon when the type or its provider is unknown.
func GetTaskIcon(taskType MaintenanceTaskType) string {
	const defaultIcon = "fas fa-cog text-muted"
	uiRegistry := tasks.GetGlobalUIRegistry()
	// Match the maintenance task type against the registered worker task types.
	for workerTaskType := range tasks.GetGlobalTypesRegistry().GetAllDetectors() {
		if string(workerTaskType) != string(taskType) {
			continue
		}
		if provider := uiRegistry.GetProvider(workerTaskType); provider != nil {
			return provider.GetIcon()
		}
		return defaultIcon
	}
	return defaultIcon
}
// GetTaskDisplayName returns the display name for a task type from its UI
// provider, falling back to the raw task type string.
func GetTaskDisplayName(taskType MaintenanceTaskType) string {
	uiRegistry := tasks.GetGlobalUIRegistry()
	// Match the maintenance task type against the registered worker task types.
	for workerTaskType := range tasks.GetGlobalTypesRegistry().GetAllDetectors() {
		if string(workerTaskType) != string(taskType) {
			continue
		}
		if provider := uiRegistry.GetProvider(workerTaskType); provider != nil {
			return provider.GetDisplayName()
		}
		return string(taskType)
	}
	return string(taskType)
}
// GetTaskDescription returns the description for a task type from its UI
// provider, falling back to a generic settings description.
func GetTaskDescription(taskType MaintenanceTaskType) string {
	fallback := "Configure detailed settings for " + string(taskType) + " tasks."
	uiRegistry := tasks.GetGlobalUIRegistry()
	// Match the maintenance task type against the registered worker task types.
	for workerTaskType := range tasks.GetGlobalTypesRegistry().GetAllDetectors() {
		if string(workerTaskType) != string(taskType) {
			continue
		}
		if provider := uiRegistry.GetProvider(workerTaskType); provider != nil {
			return provider.GetDescription()
		}
		return fallback
	}
	return fallback
}
// BuildMaintenanceMenuItems creates one menu item per registered maintenance
// task type, populated from each type's UI provider.
func BuildMaintenanceMenuItems() []*MaintenanceMenuItem {
	var menuItems []*MaintenanceMenuItem
	for _, taskType := range GetRegisteredMaintenanceTaskTypes() {
		menuItems = append(menuItems, &MaintenanceMenuItem{
			TaskType:    taskType,
			DisplayName: GetTaskDisplayName(taskType),
			Description: GetTaskDescription(taskType),
			Icon:        GetTaskIcon(taskType),
			IsEnabled:   IsMaintenanceTaskTypeRegistered(taskType),
			Path:        "/maintenance/config/" + string(taskType),
		})
	}
	return menuItems
}
// Helper functions to extract configuration fields
// Note: Removed getVacuumConfigField, getErasureCodingConfigField, getBalanceConfigField, getReplicationConfigField

View File

@@ -1,421 +0,0 @@
package maintenance

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/seaweedfs/seaweedfs/weed/glog"
	"github.com/seaweedfs/seaweedfs/weed/worker"
	"github.com/seaweedfs/seaweedfs/weed/worker/tasks"
	"github.com/seaweedfs/seaweedfs/weed/worker/types"

	// Import task packages to trigger their auto-registration
	_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/balance"
	_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/erasure_coding"
	_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/vacuum"
)
// TaskExecutor defines the function signature for executing one maintenance
// task against a worker service.
type TaskExecutor func(*MaintenanceWorkerService, *MaintenanceTask) error

// TaskExecutorFactory creates a task executor for a given worker service.
type TaskExecutorFactory func() TaskExecutor

// Global registry of task executor factories, keyed by task type.
// Guarded by executorRegistryMutex; populated lazily exactly once via
// executorRegistryInitOnce (see initializeExecutorFactories).
var taskExecutorFactories = make(map[MaintenanceTaskType]TaskExecutorFactory)
var executorRegistryMutex sync.RWMutex
var executorRegistryInitOnce sync.Once
// initializeExecutorFactories lazily registers the generic executor factory
// for every task type known to the global detector registry. Safe to call
// repeatedly; the registration work runs exactly once.
func initializeExecutorFactories() {
	executorRegistryInitOnce.Do(func() {
		registry := tasks.GetGlobalTypesRegistry()
		var taskTypes []MaintenanceTaskType
		for workerTaskType := range registry.GetAllDetectors() {
			// types.TaskType and MaintenanceTaskType share the same string values.
			taskTypes = append(taskTypes, MaintenanceTaskType(string(workerTaskType)))
		}
		for _, taskType := range taskTypes {
			RegisterTaskExecutorFactory(taskType, createGenericTaskExecutor)
		}
		glog.V(1).Infof("Dynamically registered generic task executor for %d task types: %v", len(taskTypes), taskTypes)
	})
}
// RegisterTaskExecutorFactory registers a factory function for creating task
// executors for the given task type, replacing any previous registration.
func RegisterTaskExecutorFactory(taskType MaintenanceTaskType, factory TaskExecutorFactory) {
	executorRegistryMutex.Lock()
	taskExecutorFactories[taskType] = factory
	executorRegistryMutex.Unlock()
	glog.V(2).Infof("Registered executor factory for task type: %s", taskType)
}
// GetTaskExecutorFactory returns the registered factory for taskType and
// whether one exists, initializing the registry on first use.
func GetTaskExecutorFactory(taskType MaintenanceTaskType) (TaskExecutorFactory, bool) {
	initializeExecutorFactories()
	executorRegistryMutex.RLock()
	defer executorRegistryMutex.RUnlock()
	f, ok := taskExecutorFactories[taskType]
	return f, ok
}
// GetSupportedExecutorTaskTypes returns every task type that currently has a
// registered executor factory, initializing the registry on first use.
func GetSupportedExecutorTaskTypes() []MaintenanceTaskType {
	initializeExecutorFactories()
	executorRegistryMutex.RLock()
	defer executorRegistryMutex.RUnlock()
	supported := make([]MaintenanceTaskType, 0, len(taskExecutorFactories))
	for taskType := range taskExecutorFactories {
		supported = append(supported, taskType)
	}
	return supported
}
// createGenericTaskExecutor creates a generic task executor that delegates to
// the worker service's registry-driven execution path (executeGenericTask).
func createGenericTaskExecutor() TaskExecutor {
	return func(mws *MaintenanceWorkerService, task *MaintenanceTask) error {
		return mws.executeGenericTask(task)
	}
}
// init does minimal initialization - the actual executor factory registration
// happens lazily on first access (see initializeExecutorFactories).
func init() {
	// Executor factory registration will happen lazily when first accessed
	glog.V(1).Infof("Maintenance worker initialized - executor factories will be registered on first access")
}
// MaintenanceWorkerService manages maintenance task execution for one worker.
//
// NOTE(review): currentTasks is written in executeTask, deleted from inside
// its goroutines, and read by requestTasks/Stop without synchronization —
// confirm callers serialize access or add a mutex.
type MaintenanceWorkerService struct {
	workerID      string                      // unique identifier of this worker
	address       string                      // network address the worker advertises
	adminServer   string                      // admin server address this worker reports to
	capabilities  []MaintenanceTaskType       // task types this worker can execute
	maxConcurrent int                         // maximum tasks executed at once
	currentTasks  map[string]*MaintenanceTask // in-flight tasks keyed by task ID
	queue         *MaintenanceQueue           // task queue; may be nil when detached
	adminClient   AdminClient                 // client for admin server communication
	running       bool                        // set by Start, cleared by Stop
	stopChan      chan struct{}               // closed by Stop to end the worker loop
	// Task execution registry (task type -> executor function)
	taskExecutors map[MaintenanceTaskType]TaskExecutor
	// Task registry for creating task instances
	taskRegistry *tasks.TaskRegistry
}
// NewMaintenanceWorkerService creates a new maintenance worker service with
// default concurrency, the globally registered task capabilities, and
// executors built from the global factory registry.
func NewMaintenanceWorkerService(workerID, address, adminServer string) *MaintenanceWorkerService {
	mws := &MaintenanceWorkerService{
		workerID:      workerID,
		address:       address,
		adminServer:   adminServer,
		capabilities:  GetRegisteredMaintenanceTaskTypes(),
		maxConcurrent: 2, // Default concurrent task limit
		currentTasks:  make(map[string]*MaintenanceTask),
		stopChan:      make(chan struct{}),
		taskExecutors: make(map[MaintenanceTaskType]TaskExecutor),
		taskRegistry:  tasks.GetGlobalTaskRegistry(), // Use global registry with auto-registered tasks
	}
	mws.initializeTaskExecutors()
	glog.V(1).Infof("Created maintenance worker with %d registered task types", len(mws.taskRegistry.GetAll()))
	return mws
}
// executeGenericTask executes a task using the task registry instead of
// hardcoded per-type methods.
//
// It validates that the task carries typed parameters, looks up the task
// definition in the registry, creates an instance, and runs it, reporting
// coarse progress (5% at start, 100% on success) back to the queue.
func (mws *MaintenanceWorkerService) executeGenericTask(task *MaintenanceTask) error {
	glog.V(2).Infof("Executing generic task %s: %s for volume %d", task.ID, task.Type, task.VolumeID)

	// Validate that task has proper typed parameters
	if task.TypedParams == nil {
		return fmt.Errorf("task %s has no typed parameters - task was not properly planned (insufficient destinations)", task.ID)
	}

	// Convert MaintenanceTask type to the worker task type
	taskType := types.TaskType(string(task.Type))

	// Look up the definition before calling Create: a registry miss would
	// otherwise panic with a nil dereference instead of failing the task.
	// NOTE(review): assumes Get returns nil for unknown types — confirm.
	taskDefinition := mws.taskRegistry.Get(taskType)
	if taskDefinition == nil {
		return fmt.Errorf("no task registered for type %s", taskType)
	}
	taskInstance, err := taskDefinition.Create(task.TypedParams)
	if err != nil {
		return fmt.Errorf("failed to create task instance: %w", err)
	}

	// Update progress to show task has started
	mws.updateTaskProgress(task.ID, 5)

	// Execute the task
	if err := taskInstance.Execute(context.Background(), task.TypedParams); err != nil {
		return fmt.Errorf("task execution failed: %w", err)
	}

	// Update progress to show completion
	mws.updateTaskProgress(task.ID, 100)
	glog.V(2).Infof("Generic task %s completed successfully", task.ID)
	return nil
}
// initializeTaskExecutors builds this worker's executor table from the
// globally registered executor factories.
func (mws *MaintenanceWorkerService) initializeTaskExecutors() {
	executors := make(map[MaintenanceTaskType]TaskExecutor)
	executorRegistryMutex.RLock()
	for taskType, factory := range taskExecutorFactories {
		executors[taskType] = factory()
		glog.V(3).Infof("Initialized executor for task type: %s", taskType)
	}
	executorRegistryMutex.RUnlock()
	mws.taskExecutors = executors
	glog.V(2).Infof("Initialized %d task executors", len(mws.taskExecutors))
}
// RegisterTaskExecutor allows dynamic registration of a new task executor on
// this worker, replacing any existing executor for the same type.
func (mws *MaintenanceWorkerService) RegisterTaskExecutor(taskType MaintenanceTaskType, executor TaskExecutor) {
	if mws.taskExecutors == nil {
		mws.taskExecutors = map[MaintenanceTaskType]TaskExecutor{}
	}
	mws.taskExecutors[taskType] = executor
	glog.V(1).Infof("Registered executor for task type: %s", taskType)
}
// GetSupportedTaskTypes returns all task types that this worker can execute,
// i.e. every type with a registered executor factory.
func (mws *MaintenanceWorkerService) GetSupportedTaskTypes() []MaintenanceTaskType {
	return GetSupportedExecutorTaskTypes()
}
// Start registers the worker with its queue (when one is attached) and
// launches the background worker loop. It always returns nil.
func (mws *MaintenanceWorkerService) Start() error {
	mws.running = true

	// Announce this worker and its capabilities to the admin-side queue.
	registration := &MaintenanceWorker{
		ID:            mws.workerID,
		Address:       mws.address,
		Capabilities:  mws.capabilities,
		MaxConcurrent: mws.maxConcurrent,
	}
	if mws.queue != nil {
		mws.queue.RegisterWorker(registration)
	}

	go mws.workerLoop()
	glog.Infof("Maintenance worker %s started at %s", mws.workerID, mws.address)
	return nil
}
// Stop terminates the worker service: it signals the worker loop to exit and
// waits up to 30 seconds (polling once per second) for in-flight tasks to
// drain before returning.
//
// NOTE(review): calling Stop twice panics on the second close(stopChan), and
// len(mws.currentTasks) is read here while executeTask goroutines delete from
// the map concurrently — confirm single-caller usage or add locking.
func (mws *MaintenanceWorkerService) Stop() {
	mws.running = false
	close(mws.stopChan)
	// Wait for current tasks to complete or timeout
	timeout := time.NewTimer(30 * time.Second)
	defer timeout.Stop()
	for len(mws.currentTasks) > 0 {
		select {
		case <-timeout.C:
			glog.Warningf("Worker %s stopping with %d tasks still running", mws.workerID, len(mws.currentTasks))
			return
		case <-time.After(time.Second):
			// Poll the in-flight task count again after a second
		}
	}
	glog.Infof("Maintenance worker %s stopped", mws.workerID)
}
// workerLoop is the main worker event loop: it sends a heartbeat every 30s
// and polls for new tasks every 5s until stopChan closes or running is false.
func (mws *MaintenanceWorkerService) workerLoop() {
	heartbeat := time.NewTicker(30 * time.Second)
	taskRequest := time.NewTicker(5 * time.Second)
	defer heartbeat.Stop()
	defer taskRequest.Stop()

	for mws.running {
		select {
		case <-mws.stopChan:
			return
		case <-heartbeat.C:
			mws.sendHeartbeat()
		case <-taskRequest.C:
			mws.requestTasks()
		}
	}
}
// sendHeartbeat refreshes this worker's liveness timestamp in the queue, if a
// queue is attached.
func (mws *MaintenanceWorkerService) sendHeartbeat() {
	if mws.queue != nil {
		mws.queue.UpdateWorkerHeartbeat(mws.workerID)
	}
}
// requestTasks pulls the next matching task from the queue (if any) and
// starts executing it, unless the worker is already at its concurrency limit.
//
// NOTE(review): len(mws.currentTasks) races with the deletes performed inside
// executeTask's goroutines — confirm intended or guard the map.
func (mws *MaintenanceWorkerService) requestTasks() {
	if len(mws.currentTasks) >= mws.maxConcurrent {
		return // Already at capacity
	}
	if mws.queue != nil {
		task := mws.queue.GetNextTask(mws.workerID, mws.capabilities)
		if task != nil {
			mws.executeTask(task)
		}
	}
}
// executeTask records the task as in-flight and runs it asynchronously via
// the executor registered for its type, reporting completion (or the error)
// back to the queue when done.
//
// NOTE(review): mws.currentTasks is written here and deleted inside the
// goroutine without a lock; concurrent map access is a data race — confirm
// the worker loop is the only caller and consider a mutex on the struct.
func (mws *MaintenanceWorkerService) executeTask(task *MaintenanceTask) {
	mws.currentTasks[task.ID] = task
	go func() {
		defer func() {
			// Remove from the in-flight set regardless of outcome.
			delete(mws.currentTasks, task.ID)
		}()
		glog.Infof("Worker %s executing task %s: %s", mws.workerID, task.ID, task.Type)
		// Execute task using dynamic executor registry
		var err error
		if executor, exists := mws.taskExecutors[task.Type]; exists {
			err = executor(mws, task)
		} else {
			err = fmt.Errorf("unsupported task type: %s", task.Type)
			glog.Errorf("No executor registered for task type: %s", task.Type)
		}
		// Report task completion (an empty error string signals success)
		if mws.queue != nil {
			errorMsg := ""
			if err != nil {
				errorMsg = err.Error()
			}
			mws.queue.CompleteTask(task.ID, errorMsg)
		}
		if err != nil {
			glog.Errorf("Worker %s failed to execute task %s: %v", mws.workerID, task.ID, err)
		} else {
			glog.Infof("Worker %s completed task %s successfully", mws.workerID, task.ID)
		}
	}()
}
// updateTaskProgress forwards a task's progress percentage to the queue, if a
// queue is attached.
func (mws *MaintenanceWorkerService) updateTaskProgress(taskID string, progress float64) {
	if mws.queue != nil {
		mws.queue.UpdateTaskProgress(taskID, progress)
	}
}
// GetStatus returns a snapshot of the worker's current state for reporting.
//
// The task_details entry is a copy of the in-flight task map so callers
// cannot mutate (or iterate concurrently over) the worker's internal
// bookkeeping map.
func (mws *MaintenanceWorkerService) GetStatus() map[string]interface{} {
	// Snapshot the in-flight tasks instead of leaking the live map.
	taskDetails := make(map[string]*MaintenanceTask, len(mws.currentTasks))
	for id, task := range mws.currentTasks {
		taskDetails[id] = task
	}
	return map[string]interface{}{
		"worker_id":      mws.workerID,
		"address":        mws.address,
		"running":        mws.running,
		"capabilities":   mws.capabilities,
		"max_concurrent": mws.maxConcurrent,
		"current_tasks":  len(taskDetails),
		"task_details":   taskDetails,
	}
}
// SetQueue sets the maintenance queue used for registration, heartbeats,
// task polling, and completion reporting.
func (mws *MaintenanceWorkerService) SetQueue(queue *MaintenanceQueue) {
	mws.queue = queue
}

// SetAdminClient sets the admin client used for admin server communication.
func (mws *MaintenanceWorkerService) SetAdminClient(client AdminClient) {
	mws.adminClient = client
}

// SetCapabilities overrides the task types this worker advertises.
func (mws *MaintenanceWorkerService) SetCapabilities(capabilities []MaintenanceTaskType) {
	mws.capabilities = capabilities
}

// SetMaxConcurrent sets the maximum number of tasks executed concurrently.
func (mws *MaintenanceWorkerService) SetMaxConcurrent(max int) {
	mws.maxConcurrent = max
}

// SetHeartbeatInterval sets the heartbeat interval (placeholder: the interval
// is currently fixed at 30s inside workerLoop).
func (mws *MaintenanceWorkerService) SetHeartbeatInterval(interval time.Duration) {
	// Future implementation for configurable heartbeat
}

// SetTaskRequestInterval sets the task request interval (placeholder: the
// interval is currently fixed at 5s inside workerLoop).
func (mws *MaintenanceWorkerService) SetTaskRequestInterval(interval time.Duration) {
	// Future implementation for configurable task requests
}
// MaintenanceWorkerCommand represents a standalone maintenance worker command
// wrapping a single MaintenanceWorkerService.
type MaintenanceWorkerCommand struct {
	workerService *MaintenanceWorkerService // the managed worker service
}

// NewMaintenanceWorkerCommand creates a new worker command for the given
// worker identity and admin server address.
func NewMaintenanceWorkerCommand(workerID, address, adminServer string) *MaintenanceWorkerCommand {
	return &MaintenanceWorkerCommand{
		workerService: NewMaintenanceWorkerService(workerID, address, adminServer),
	}
}
// Run starts the maintenance worker as a standalone service and blocks until
// an interrupt (SIGINT/SIGTERM) is received, then stops the worker
// gracefully.
//
// If no worker ID was provided, a persistent one is generated or loaded from
// the current working directory.
func (mwc *MaintenanceWorkerCommand) Run() error {
	// Generate or load persistent worker ID if not provided
	if mwc.workerService.workerID == "" {
		// The working directory anchors worker ID persistence.
		wd, err := os.Getwd()
		if err != nil {
			return fmt.Errorf("failed to get working directory: %w", err)
		}
		workerID, err := worker.GenerateOrLoadWorkerID(wd)
		if err != nil {
			return fmt.Errorf("failed to generate or load worker ID: %w", err)
		}
		mwc.workerService.workerID = workerID
	}

	// Start the worker service
	if err := mwc.workerService.Start(); err != nil {
		return fmt.Errorf("failed to start maintenance worker: %w", err)
	}

	// Block until an interrupt/termination signal arrives, then shut down.
	// (The previous `select {}` blocked forever and never actually honored
	// the advertised "wait for interrupt" behavior.)
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	<-sigCh
	mwc.workerService.Stop()
	return nil
}

View File

@@ -122,6 +122,7 @@ type Plugin struct {
type streamSession struct { type streamSession struct {
workerID string workerID string
outgoing chan *plugin_pb.AdminToWorkerMessage outgoing chan *plugin_pb.AdminToWorkerMessage
done chan struct{}
closeOnce sync.Once closeOnce sync.Once
} }
@@ -274,6 +275,7 @@ func (r *Plugin) WorkerStream(stream plugin_pb.PluginControlService_WorkerStream
session := &streamSession{ session := &streamSession{
workerID: workerID, workerID: workerID,
outgoing: make(chan *plugin_pb.AdminToWorkerMessage, r.outgoingBuffer), outgoing: make(chan *plugin_pb.AdminToWorkerMessage, r.outgoingBuffer),
done: make(chan struct{}),
} }
r.putSession(session) r.putSession(session)
defer r.cleanupSession(workerID) defer r.cleanupSession(workerID)
@@ -908,8 +910,10 @@ func (r *Plugin) sendLoop(
return nil return nil
case <-r.shutdownCh: case <-r.shutdownCh:
return nil return nil
case msg, ok := <-session.outgoing: case <-session.done:
if !ok { return nil
case msg := <-session.outgoing:
if msg == nil {
return nil return nil
} }
if err := stream.Send(msg); err != nil { if err := stream.Send(msg); err != nil {
@@ -930,6 +934,8 @@ func (r *Plugin) sendToWorker(workerID string, message *plugin_pb.AdminToWorkerM
select { select {
case <-r.shutdownCh: case <-r.shutdownCh:
return fmt.Errorf("plugin is shutting down") return fmt.Errorf("plugin is shutting down")
case <-session.done:
return fmt.Errorf("worker %s session is closed", workerID)
case session.outgoing <- message: case session.outgoing <- message:
return nil return nil
case <-time.After(r.sendTimeout): case <-time.After(r.sendTimeout):
@@ -1425,7 +1431,7 @@ func CloneConfigValueMap(in map[string]*plugin_pb.ConfigValue) map[string]*plugi
func (s *streamSession) close() { func (s *streamSession) close() {
s.closeOnce.Do(func() { s.closeOnce.Do(func() {
close(s.outgoing) close(s.done)
}) })
} }

View File

@@ -26,7 +26,7 @@ func TestRunDetectionSendsCancelOnContextDone(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)} session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -77,7 +77,7 @@ func TestExecuteJobSendsCancelOnContextDone(t *testing.T) {
{JobType: jobType, CanExecute: true, MaxExecutionConcurrency: 1}, {JobType: jobType, CanExecute: true, MaxExecutionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)} session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
job := &plugin_pb.JobSpec{JobId: "job-1", JobType: jobType} job := &plugin_pb.JobSpec{JobId: "job-1", JobType: jobType}
@@ -135,8 +135,8 @@ func TestAdminScriptExecutionBlocksOtherDetection(t *testing.T) {
{JobType: "vacuum", CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: "vacuum", CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
pluginSvc.putSession(adminSession) pluginSvc.putSession(adminSession)
pluginSvc.putSession(otherSession) pluginSvc.putSession(otherSession)
@@ -214,8 +214,8 @@ func TestAdminScriptExecutionBlocksOtherExecution(t *testing.T) {
{JobType: "vacuum", CanExecute: true, MaxExecutionConcurrency: 1}, {JobType: "vacuum", CanExecute: true, MaxExecutionConcurrency: 1},
}, },
}) })
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
pluginSvc.putSession(adminSession) pluginSvc.putSession(adminSession)
pluginSvc.putSession(otherSession) pluginSvc.putSession(otherSession)

View File

@@ -22,7 +22,7 @@ func TestRunDetectionIncludesLatestSuccessfulRun(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
oldSuccess := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) oldSuccess := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
@@ -80,7 +80,7 @@ func TestRunDetectionOmitsLastSuccessfulRunWhenNoSuccessHistory(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
if err := pluginSvc.store.AppendRunRecord(jobType, &JobRunRecord{ if err := pluginSvc.store.AppendRunRecord(jobType, &JobRunRecord{
@@ -130,7 +130,7 @@ func TestRunDetectionWithReportCapturesDetectionActivities(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
reportCh := make(chan *DetectionReport, 1) reportCh := make(chan *DetectionReport, 1)
@@ -210,7 +210,7 @@ func TestRunDetectionAdminScriptUsesLastCompletedRun(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
}, },
}) })
session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session) pluginSvc.putSession(session)
successCompleted := time.Date(2026, 2, 1, 10, 0, 0, 0, time.UTC) successCompleted := time.Date(2026, 2, 1, 10, 0, 0, 0, time.UTC)

View File

@@ -95,16 +95,6 @@ func (r *Plugin) laneSchedulerLoop(ls *schedulerLaneState) {
} }
} }
// schedulerLoop is kept for backward compatibility; it delegates to
// laneSchedulerLoop with the default lane. New code should not call this.
func (r *Plugin) schedulerLoop() {
	ls, ok := r.lanes[LaneDefault]
	if !ok || ls == nil {
		ls = newLaneState(LaneDefault)
	}
	r.laneSchedulerLoop(ls)
}
// runLaneSchedulerIteration runs one scheduling pass for a single lane, // runLaneSchedulerIteration runs one scheduling pass for a single lane,
// processing only the job types assigned to that lane. // processing only the job types assigned to that lane.
// //
@@ -229,82 +219,6 @@ func (r *Plugin) runLaneSchedulerIterationConcurrent(ls *schedulerLaneState, job
return hadJobs.Load() return hadJobs.Load()
} }
// runSchedulerIteration is kept for backward compatibility. It runs a
// single iteration across ALL job types (equivalent to the old single-loop
// behavior). It is only used by the legacy schedulerLoop() fallback.
//
// The pass: expire stale jobs, acquire the admin lock, then for every
// detectable job type load its policy, check whether detection is due, and
// run one detection/scheduling iteration. Returns true when at least one job
// type produced jobs.
//
// NOTE(review): ls is computed below but never used afterwards — looks like a
// leftover from the lane refactor; confirm before removing.
func (r *Plugin) runSchedulerIteration() bool {
	ls := r.lanes[LaneDefault]
	if ls == nil {
		ls = newLaneState(LaneDefault)
	}
	// For backward compat, the old function processes all job types.
	r.expireStaleJobs(time.Now().UTC())
	jobTypes := r.registry.DetectableJobTypes()
	if len(jobTypes) == 0 {
		r.setSchedulerLoopState("", "idle")
		return false
	}
	r.setSchedulerLoopState("", "waiting_for_lock")
	releaseLock, err := r.acquireAdminLock("plugin scheduler iteration")
	if err != nil {
		glog.Warningf("Plugin scheduler failed to acquire lock: %v", err)
		r.setSchedulerLoopState("", "idle")
		return false
	}
	if releaseLock != nil {
		defer releaseLock()
	}
	// Track the job types seen this pass so stale state can be pruned after.
	active := make(map[string]struct{}, len(jobTypes))
	hadJobs := false
	for _, jobType := range jobTypes {
		active[jobType] = struct{}{}
		policy, enabled, err := r.loadSchedulerPolicy(jobType)
		if err != nil {
			glog.Warningf("Plugin scheduler failed to load policy for %s: %v", jobType, err)
			continue
		}
		if !enabled {
			// Disabled job types get their scheduler state cleared.
			r.clearSchedulerJobType(jobType)
			continue
		}
		// A job type that has never run gets a short initial delay.
		initialDelay := time.Duration(0)
		if runInfo := r.snapshotSchedulerRun(jobType); runInfo.lastRunStartedAt.IsZero() {
			initialDelay = 5 * time.Second
		}
		if !r.markDetectionDue(jobType, policy.DetectionInterval, initialDelay) {
			continue
		}
		detected := r.runJobTypeIteration(jobType, policy)
		if detected {
			hadJobs = true
		}
	}
	r.pruneSchedulerState(active)
	r.pruneDetectorLeases(active)
	r.setSchedulerLoopState("", "idle")
	return hadJobs
}
// wakeLane wakes the scheduler goroutine for a specific lane, dropping the
// wake-up when one is already pending (non-blocking send).
func (r *Plugin) wakeLane(lane SchedulerLane) {
	if r == nil {
		return
	}
	ls, ok := r.lanes[lane]
	if !ok {
		return
	}
	select {
	case ls.wakeCh <- struct{}{}:
	default:
		// A wake-up is already queued; no need to send another.
	}
}
// wakeAllLanes wakes all lane scheduler goroutines. // wakeAllLanes wakes all lane scheduler goroutines.
func (r *Plugin) wakeAllLanes() { func (r *Plugin) wakeAllLanes() {
if r == nil { if r == nil {

View File

@@ -210,16 +210,6 @@ func (r *Plugin) setSchedulerLoopStateForJobType(jobType, phase string) {
} }
} }
// recordSchedulerIterationComplete stamps the scheduler loop state with the
// outcome and completion time of the iteration that just finished.
func (r *Plugin) recordSchedulerIterationComplete(hadJobs bool) {
	if r == nil {
		return
	}
	now := time.Now().UTC()
	r.schedulerLoopMu.Lock()
	defer r.schedulerLoopMu.Unlock()
	r.schedulerLoopState.lastIterationHadJobs = hadJobs
	r.schedulerLoopState.lastIterationCompleted = now
}
func (r *Plugin) snapshotSchedulerLoopState() schedulerLoopState { func (r *Plugin) snapshotSchedulerLoopState() schedulerLoopState {
if r == nil { if r == nil {
return schedulerLoopState{} return schedulerLoopState{}

View File

@@ -6,20 +6,6 @@ import (
"strings" "strings"
) )
// getStatusColor maps a status string to its Bootstrap color class; unknown
// statuses fall back to "secondary".
func getStatusColor(status string) string {
	statusColors := map[string]string{
		"active":      "success",
		"healthy":     "success",
		"warning":     "warning",
		"critical":    "danger",
		"unreachable": "danger",
	}
	if color, ok := statusColors[status]; ok {
		return color
	}
	return "secondary"
}
// formatBytes converts bytes to human readable format // formatBytes converts bytes to human readable format
func formatBytes(bytes int64) string { func formatBytes(bytes int64) string {
if bytes == 0 { if bytes == 0 {

View File

@@ -95,18 +95,6 @@ func NewCluster() *Cluster {
} }
} }
// getGroupMembers returns the member set for the given filer group and node
// type (filer, broker, or s3); nil for any other node type.
func (cluster *Cluster) getGroupMembers(filerGroup FilerGroupName, nodeType string, createIfNotFound bool) *GroupMembers {
	switch nodeType {
	case FilerType:
		return cluster.filerGroups.getGroupMembers(filerGroup, createIfNotFound)
	case BrokerType:
		return cluster.brokerGroups.getGroupMembers(filerGroup, createIfNotFound)
	case S3Type:
		return cluster.s3Groups.getGroupMembers(filerGroup, createIfNotFound)
	default:
		return nil
	}
}
func (cluster *Cluster) AddClusterNode(ns, nodeType string, dataCenter DataCenter, rack Rack, address pb.ServerAddress, version string) []*master_pb.KeepConnectedResponse { func (cluster *Cluster) AddClusterNode(ns, nodeType string, dataCenter DataCenter, rack Rack, address pb.ServerAddress, version string) []*master_pb.KeepConnectedResponse {
filerGroup := FilerGroupName(ns) filerGroup := FilerGroupName(ns)
switch nodeType { switch nodeType {

View File

@@ -511,11 +511,6 @@ func recoveryMiddleware(next http.Handler) http.Handler {
}) })
} }
// GetAdminOptions returns a fresh zero-valued AdminOptions; used by tests.
func GetAdminOptions() *AdminOptions {
	return &AdminOptions{}
}
// loadOrGenerateSessionKeys loads or creates authentication/encryption keys for session cookies. // loadOrGenerateSessionKeys loads or creates authentication/encryption keys for session cookies.
func loadOrGenerateSessionKeys(dataDir string) ([]byte, []byte, error) { func loadOrGenerateSessionKeys(dataDir string) ([]byte, []byte, error) {
const keyLen = 32 const keyLen = 32

View File

@@ -132,16 +132,3 @@ func fetchContent(masterFn operation.GetMasterFn, grpcDialOption grpc.DialOption
content, e = io.ReadAll(rc.Body) content, e = io.ReadAll(rc.Body)
return return
} }
// WriteFile writes data to the named file, creating it with permissions perm
// if needed and truncating any existing content.
//
// This hand-rolled open/write/close (including the io.ErrShortWrite check) is
// exactly the contract of os.WriteFile, so delegate to the standard library.
func WriteFile(filename string, data []byte, perm os.FileMode) error {
	return os.WriteFile(filename, data, perm)
}

View File

@@ -57,42 +57,6 @@ func LoadCredentialConfiguration() (*CredentialConfig, error) {
}, nil }, nil
} }
// GetCredentialStoreConfig extracts credential store configuration from
// command line flags. This is used when the credential store is configured
// via the command line instead of credential.toml. Returns nil when no store
// name is given.
func GetCredentialStoreConfig(store string, config util.Configuration, prefix string) *CredentialConfig {
	if store == "" {
		return nil
	}
	cfg := &CredentialConfig{
		Store:  store,
		Config: config,
		Prefix: prefix,
	}
	return cfg
}
// MergeCredentialConfig merges command line credential configuration with the
// credential.toml configuration; command line flags take priority.
func MergeCredentialConfig(cmdLineStore string, cmdLineConfig util.Configuration, cmdLinePrefix string) (*CredentialConfig, error) {
	// A store named on the command line wins outright.
	if cmdLineStore != "" {
		glog.V(0).Infof("Using command line credential configuration: store=%s", cmdLineStore)
		return GetCredentialStoreConfig(cmdLineStore, cmdLineConfig, cmdLinePrefix), nil
	}
	// Otherwise fall back to credential.toml (may legitimately yield nil).
	config, err := LoadCredentialConfiguration()
	if err != nil {
		return nil, err
	}
	if config == nil {
		glog.V(1).Info("No credential store configured")
	}
	return config, nil
}
// NewCredentialManagerWithDefaults creates a credential manager with fallback to defaults // NewCredentialManagerWithDefaults creates a credential manager with fallback to defaults
// If explicitStore is provided, it will be used regardless of credential.toml // If explicitStore is provided, it will be used regardless of credential.toml
// If explicitStore is empty, it tries credential.toml first, then defaults to "filer_etc" // If explicitStore is empty, it tries credential.toml first, then defaults to "filer_etc"

View File

@@ -207,32 +207,6 @@ func (store *FilerEtcStore) loadPoliciesFromMultiFile(ctx context.Context, polic
}) })
} }
// migratePoliciesToMultiFile migrates IAM policies from the single legacy
// policies file to the multi-file layout: each policy is first written to its
// own file, then the legacy file is renamed aside so it is not migrated
// again on subsequent loads.
func (store *FilerEtcStore) migratePoliciesToMultiFile(ctx context.Context, policies map[string]policy_engine.PolicyDocument) error {
	glog.Infof("Migrating IAM policies to multi-file layout...")
	// 1. Save all policies to individual files
	for name, policy := range policies {
		if err := store.savePolicy(ctx, name, policy); err != nil {
			return err
		}
	}
	// 2. Rename the legacy file so future loads take the multi-file path
	return store.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
		_, err := client.AtomicRenameEntry(ctx, &filer_pb.AtomicRenameEntryRequest{
			OldDirectory: filer.IamConfigDirectory,
			OldName:      filer.IamPoliciesFile,
			NewDirectory: filer.IamConfigDirectory,
			NewName:      IamLegacyPoliciesOldFile,
		})
		if err != nil {
			glog.Errorf("Failed to rename legacy IAM policies file %s/%s to %s: %v",
				filer.IamConfigDirectory, filer.IamPoliciesFile, IamLegacyPoliciesOldFile, err)
		}
		return err
	})
}
func (store *FilerEtcStore) savePolicy(ctx context.Context, name string, document policy_engine.PolicyDocument) error { func (store *FilerEtcStore) savePolicy(ctx context.Context, name string, document policy_engine.PolicyDocument) error {
if err := validatePolicyName(name); err != nil { if err := validatePolicyName(name); err != nil {
return err return err

View File

@@ -1,221 +0,0 @@
package credential
import (
"context"
"fmt"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
"github.com/seaweedfs/seaweedfs/weed/util"
)
// MigrateCredentials migrates credentials from one store to another.
// Identities that already exist in the destination are skipped. All
// identities are attempted; an error is returned if any of them fail.
func MigrateCredentials(fromStoreName, toStoreName CredentialStoreTypeName, configuration util.Configuration, fromPrefix, toPrefix string) error {
	ctx := context.Background()

	// Create source credential manager.
	// Wrap with %w so callers can inspect the cause via errors.Is/As.
	fromCM, err := NewCredentialManager(fromStoreName, configuration, fromPrefix)
	if err != nil {
		return fmt.Errorf("failed to create source credential manager (%s): %w", fromStoreName, err)
	}
	defer fromCM.Shutdown()

	// Create destination credential manager.
	toCM, err := NewCredentialManager(toStoreName, configuration, toPrefix)
	if err != nil {
		return fmt.Errorf("failed to create destination credential manager (%s): %w", toStoreName, err)
	}
	defer toCM.Shutdown()

	// Load configuration from source.
	glog.Infof("Loading configuration from %s store...", fromStoreName)
	config, err := fromCM.LoadConfiguration(ctx)
	if err != nil {
		return fmt.Errorf("failed to load configuration from source store: %w", err)
	}
	if config == nil || len(config.Identities) == 0 {
		glog.Info("No identities found in source store")
		return nil
	}
	glog.Infof("Found %d identities in source store", len(config.Identities))

	// Migrate each identity; individual failures are counted, not fatal.
	var migrated, failed int
	for _, identity := range config.Identities {
		glog.V(1).Infof("Migrating user: %s", identity.Name)

		// Check if user already exists in destination.
		existingUser, err := toCM.GetUser(ctx, identity.Name)
		if err != nil && err != ErrUserNotFound {
			glog.Errorf("Failed to check if user %s exists in destination: %v", identity.Name, err)
			failed++
			continue
		}
		if existingUser != nil {
			glog.Warningf("User %s already exists in destination store, skipping", identity.Name)
			continue
		}

		// Create user in destination.
		if err := toCM.CreateUser(ctx, identity); err != nil {
			glog.Errorf("Failed to create user %s in destination store: %v", identity.Name, err)
			failed++
			continue
		}

		migrated++
		glog.V(1).Infof("Successfully migrated user: %s", identity.Name)
	}

	glog.Infof("Migration completed: %d migrated, %d failed", migrated, failed)

	if failed > 0 {
		return fmt.Errorf("migration completed with %d failures", failed)
	}

	return nil
}
// ExportCredentials exports all credentials from a store as an
// iam_pb.S3ApiConfiguration snapshot.
func ExportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) (*iam_pb.S3ApiConfiguration, error) {
	ctx := context.Background()

	// Create credential manager.
	// Wrap with %w so callers can inspect the cause via errors.Is/As.
	cm, err := NewCredentialManager(storeName, configuration, prefix)
	if err != nil {
		return nil, fmt.Errorf("failed to create credential manager (%s): %w", storeName, err)
	}
	defer cm.Shutdown()

	// Load configuration.
	config, err := cm.LoadConfiguration(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load configuration: %w", err)
	}

	return config, nil
}
// ImportCredentials imports credentials from a configuration snapshot into a
// store. Existing users are skipped. All identities are attempted; an error
// is returned if any of them fail.
func ImportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string, config *iam_pb.S3ApiConfiguration) error {
	ctx := context.Background()

	// Create credential manager.
	// Wrap with %w so callers can inspect the cause via errors.Is/As.
	cm, err := NewCredentialManager(storeName, configuration, prefix)
	if err != nil {
		return fmt.Errorf("failed to create credential manager (%s): %w", storeName, err)
	}
	defer cm.Shutdown()

	// Guard against a nil snapshot: ranging over config.Identities below
	// would otherwise panic with a nil pointer dereference.
	if config == nil || len(config.Identities) == 0 {
		glog.Info("No identities to import")
		return nil
	}

	// Import each identity; individual failures are counted, not fatal.
	var imported, failed int
	for _, identity := range config.Identities {
		glog.V(1).Infof("Importing user: %s", identity.Name)

		// Check if user already exists.
		existingUser, err := cm.GetUser(ctx, identity.Name)
		if err != nil && err != ErrUserNotFound {
			glog.Errorf("Failed to check if user %s exists: %v", identity.Name, err)
			failed++
			continue
		}
		if existingUser != nil {
			glog.Warningf("User %s already exists, skipping", identity.Name)
			continue
		}

		// Create user.
		if err := cm.CreateUser(ctx, identity); err != nil {
			glog.Errorf("Failed to create user %s: %v", identity.Name, err)
			failed++
			continue
		}

		imported++
		glog.V(1).Infof("Successfully imported user: %s", identity.Name)
	}

	glog.Infof("Import completed: %d imported, %d failed", imported, failed)

	if failed > 0 {
		return fmt.Errorf("import completed with %d failures", failed)
	}

	return nil
}
// ValidateCredentials validates that all credentials in a store are
// accessible: each identity must be retrievable by name, and each of its
// access keys must map back to the same identity. Returns an error if any
// check fails (after attempting all identities).
func ValidateCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) error {
	ctx := context.Background()

	// Create credential manager.
	// Wrap with %w so callers can inspect the cause via errors.Is/As.
	cm, err := NewCredentialManager(storeName, configuration, prefix)
	if err != nil {
		return fmt.Errorf("failed to create credential manager (%s): %w", storeName, err)
	}
	defer cm.Shutdown()

	// Load configuration.
	config, err := cm.LoadConfiguration(ctx)
	if err != nil {
		return fmt.Errorf("failed to load configuration: %w", err)
	}
	if config == nil || len(config.Identities) == 0 {
		glog.Info("No identities found in store")
		return nil
	}

	glog.Infof("Validating %d identities...", len(config.Identities))

	var validated, failed int
	for _, identity := range config.Identities {
		// Check if user can be retrieved.
		user, err := cm.GetUser(ctx, identity.Name)
		if err != nil {
			glog.Errorf("Failed to retrieve user %s: %v", identity.Name, err)
			failed++
			continue
		}
		if user == nil {
			glog.Errorf("User %s not found", identity.Name)
			failed++
			continue
		}

		// Validate access keys. BUG FIX: previously the inner `continue`
		// only skipped the bad key, so a user with failing keys was still
		// counted as validated. Track per-user key health instead.
		keysOK := true
		for _, credential := range identity.Credentials {
			accessKeyUser, err := cm.GetUserByAccessKey(ctx, credential.AccessKey)
			if err != nil {
				glog.Errorf("Failed to retrieve user by access key %s: %v", credential.AccessKey, err)
				failed++
				keysOK = false
				continue
			}
			if accessKeyUser == nil || accessKeyUser.Name != identity.Name {
				glog.Errorf("Access key %s does not map to correct user %s", credential.AccessKey, identity.Name)
				failed++
				keysOK = false
				continue
			}
		}
		if !keysOK {
			continue
		}

		validated++
		glog.V(1).Infof("Successfully validated user: %s", identity.Name)
	}

	glog.Infof("Validation completed: %d validated, %d failed", validated, failed)

	if failed > 0 {
		return fmt.Errorf("validation completed with %d failures", failed)
	}

	return nil
}

View File

@@ -246,10 +246,6 @@ func NewLogFileEntryCollector(f *Filer, startPosition log_buffer.MessagePosition
} }
} }
// hasMore reports whether any per-day log entries remain queued for collection.
func (c *LogFileEntryCollector) hasMore() bool {
	remaining := c.dayEntryQueue.Len()
	return remaining > 0
}
func (c *LogFileEntryCollector) collectMore(v *OrderedLogVisitor) (err error) { func (c *LogFileEntryCollector) collectMore(v *OrderedLogVisitor) (err error) {
dayEntry := c.dayEntryQueue.Dequeue() dayEntry := c.dayEntryQueue.Dequeue()
if dayEntry == nil { if dayEntry == nil {

View File

@@ -2,7 +2,6 @@ package filer
import ( import (
"context" "context"
"sync"
"github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
@@ -36,39 +35,3 @@ func Replay(filerStore FilerStore, resp *filer_pb.SubscribeMetadataResponse) err
return nil return nil
} }
// ParallelProcessDirectoryStructure processes each entry in parallel while
// ensuring parent directories are processed before their children:
// directories are handled synchronously in channel order (parents are assumed
// to arrive before children), while files are fanned out to a bounded
// executor. The first error observed is returned; once an error occurs,
// remaining entries are skipped.
func ParallelProcessDirectoryStructure(entryChan chan *Entry, concurrency int, eachEntryFn func(entry *Entry) error) (firstErr error) {
	executors := util.NewLimitedConcurrentExecutor(concurrency)

	var wg sync.WaitGroup
	// BUG FIX: firstErr was previously written from executor goroutines and
	// read in the loop with no synchronization — a data race. Guard it.
	var mu sync.Mutex

	recordErr := func(err error) {
		mu.Lock()
		defer mu.Unlock()
		if firstErr == nil {
			firstErr = err
		}
	}
	hasErr := func() bool {
		mu.Lock()
		defer mu.Unlock()
		return firstErr != nil
	}

	for entry := range entryChan {
		wg.Add(1)
		if entry.IsDirectory() {
			// Directories run inline so a parent completes before any of
			// its children (queued later) are processed.
			func() {
				defer wg.Done()
				if err := eachEntryFn(entry); err != nil {
					recordErr(err)
				}
			}()
		} else {
			entry := entry // capture per-iteration value for the goroutine (pre-Go 1.22 loop-var semantics)
			executors.Execute(func() {
				defer wg.Done()
				if err := eachEntryFn(entry); err != nil {
					recordErr(err)
				}
			})
		}
		if hasErr() {
			break
		}
	}
	// All submitted work has finished before firstErr is read for the return.
	wg.Wait()
	return
}

View File

@@ -16,15 +16,6 @@ type ItemList struct {
prefix string prefix string
} }
// newItemList builds an ItemList backed by a Redis-persisted skip list.
func newItemList(client redis.UniversalClient, prefix string, store skiplist.ListStore, batchSize int) *ItemList {
	list := &ItemList{
		skipList:  skiplist.New(store),
		batchSize: batchSize,
		client:    client,
		prefix:    prefix,
	}
	return list
}
/*
Be reluctant to create new nodes. Try to fit into either previous node or next node.
Prefer to add to previous node.

View File

@@ -1,48 +0,0 @@
package redis_lua
import (
"github.com/redis/go-redis/v9"
"github.com/seaweedfs/seaweedfs/weed/filer"
"github.com/seaweedfs/seaweedfs/weed/util"
)
// init registers the redis_lua_cluster store with the filer store registry
// at program start.
func init() {
	filer.Stores = append(filer.Stores, &RedisLuaClusterStore{})
}
// RedisLuaClusterStore is a filer store backed by a Redis Cluster, using
// server-side Lua scripts (via the embedded UniversalRedisLuaStore) for
// atomic entry updates.
type RedisLuaClusterStore struct {
	UniversalRedisLuaStore
}
// GetName returns the store name used to select this backend in filer.toml.
func (store *RedisLuaClusterStore) GetName() string {
	const name = "redis_lua_cluster"
	return name
}
// Initialize reads the Redis Cluster connection settings from the
// configuration (under the given prefix) and connects the underlying client.
func (store *RedisLuaClusterStore) Initialize(configuration util.Configuration, prefix string) (err error) {
	// Defaults must be set before the corresponding Get calls.
	configuration.SetDefault(prefix+"useReadOnly", false)
	configuration.SetDefault(prefix+"routeByLatency", false)

	addresses := configuration.GetStringSlice(prefix + "addresses")
	username := configuration.GetString(prefix + "username")
	password := configuration.GetString(prefix + "password")
	keyPrefix := configuration.GetString(prefix + "keyPrefix")
	useReadOnly := configuration.GetBool(prefix + "useReadOnly")
	routeByLatency := configuration.GetBool(prefix + "routeByLatency")
	superLargeDirectories := configuration.GetStringSlice(prefix + "superLargeDirectories")

	return store.initialize(addresses, username, password, keyPrefix, useReadOnly, routeByLatency, superLargeDirectories)
}
// initialize wires up the Redis Cluster client and the per-store settings.
func (store *RedisLuaClusterStore) initialize(addresses []string, username string, password string, keyPrefix string, readOnly, routeByLatency bool, superLargeDirectories []string) (err error) {
	options := &redis.ClusterOptions{
		Addrs:          addresses,
		Username:       username,
		Password:       password,
		ReadOnly:       readOnly,
		RouteByLatency: routeByLatency,
	}
	store.Client = redis.NewClusterClient(options)
	store.keyPrefix = keyPrefix
	store.loadSuperLargeDirectories(superLargeDirectories)
	return
}

View File

@@ -1,48 +0,0 @@
package redis_lua
import (
"time"
"github.com/redis/go-redis/v9"
"github.com/seaweedfs/seaweedfs/weed/filer"
"github.com/seaweedfs/seaweedfs/weed/util"
)
// init registers the redis_lua_sentinel store with the filer store registry
// at program start.
func init() {
	filer.Stores = append(filer.Stores, &RedisLuaSentinelStore{})
}
// RedisLuaSentinelStore is a filer store backed by Redis with Sentinel
// failover, using server-side Lua scripts (via the embedded
// UniversalRedisLuaStore) for atomic entry updates.
type RedisLuaSentinelStore struct {
	UniversalRedisLuaStore
}
// GetName returns the store name used to select this backend in filer.toml.
func (store *RedisLuaSentinelStore) GetName() string {
	const name = "redis_lua_sentinel"
	return name
}
// Initialize reads the Sentinel connection settings from the configuration
// (under the given prefix) and connects the underlying client.
func (store *RedisLuaSentinelStore) Initialize(configuration util.Configuration, prefix string) (err error) {
	addresses := configuration.GetStringSlice(prefix + "addresses")
	masterName := configuration.GetString(prefix + "masterName")
	username := configuration.GetString(prefix + "username")
	password := configuration.GetString(prefix + "password")
	database := configuration.GetInt(prefix + "database")
	keyPrefix := configuration.GetString(prefix + "keyPrefix")

	return store.initialize(addresses, masterName, username, password, database, keyPrefix)
}
// initialize wires up the Sentinel-failover Redis client with retry backoff
// and timeout settings, plus the per-store key prefix.
func (store *RedisLuaSentinelStore) initialize(addresses []string, masterName string, username string, password string, database int, keyPrefix string) (err error) {
	options := &redis.FailoverOptions{
		MasterName:      masterName,
		SentinelAddrs:   addresses,
		Username:        username,
		Password:        password,
		DB:              database,
		MinRetryBackoff: time.Millisecond * 100,
		MaxRetryBackoff: time.Minute * 1,
		ReadTimeout:     time.Second * 30,
		WriteTimeout:    time.Second * 5,
	}
	store.Client = redis.NewFailoverClient(options)
	store.keyPrefix = keyPrefix
	return
}

View File

@@ -1,42 +0,0 @@
package redis_lua
import (
"github.com/redis/go-redis/v9"
"github.com/seaweedfs/seaweedfs/weed/filer"
"github.com/seaweedfs/seaweedfs/weed/util"
)
// init registers the single-node redis_lua store with the filer store
// registry at program start.
func init() {
	filer.Stores = append(filer.Stores, &RedisLuaStore{})
}
// RedisLuaStore is a filer store backed by a single Redis server, using
// server-side Lua scripts (via the embedded UniversalRedisLuaStore) for
// atomic entry updates.
type RedisLuaStore struct {
	UniversalRedisLuaStore
}
// GetName returns the store name used to select this backend in filer.toml.
func (store *RedisLuaStore) GetName() string {
	const name = "redis_lua"
	return name
}
// Initialize reads the single-node Redis connection settings from the
// configuration (under the given prefix) and connects the underlying client.
func (store *RedisLuaStore) Initialize(configuration util.Configuration, prefix string) (err error) {
	address := configuration.GetString(prefix + "address")
	username := configuration.GetString(prefix + "username")
	password := configuration.GetString(prefix + "password")
	database := configuration.GetInt(prefix + "database")
	keyPrefix := configuration.GetString(prefix + "keyPrefix")
	superLargeDirectories := configuration.GetStringSlice(prefix + "superLargeDirectories")

	return store.initialize(address, username, password, database, keyPrefix, superLargeDirectories)
}
// initialize wires up the single-node Redis client and per-store settings.
func (store *RedisLuaStore) initialize(hostPort string, username string, password string, database int, keyPrefix string, superLargeDirectories []string) (err error) {
	options := &redis.Options{
		Addr:     hostPort,
		Username: username,
		Password: password,
		DB:       database,
	}
	store.Client = redis.NewClient(options)
	store.keyPrefix = keyPrefix
	store.loadSuperLargeDirectories(superLargeDirectories)
	return
}

View File

@@ -1,19 +0,0 @@
-- Atomically deletes a filer entry: removes the entry key and its own
-- child-listing key, and drops its name from the parent directory listing
-- (skipped for super-large directories, whose listings are not maintained).
-- KEYS[1]: full path of the entry
local fullpath = KEYS[1]
-- KEYS[2]: directory-listing key of the entry itself (its children zset)
local fullpath_list_key = KEYS[2]
-- KEYS[3]: directory-listing key of the entry's parent
local dir_list_key = KEYS[3]
-- ARGV[1]: "1" when the parent is a super-large directory
local isSuperLargeDirectory = ARGV[1] == "1"
-- ARGV[2]: name of the entry
local name = ARGV[2]

redis.call("DEL", fullpath, fullpath_list_key)

if not isSuperLargeDirectory and name ~= "" then
	redis.call("ZREM", dir_list_key, name)
end
return 0

View File

@@ -1,15 +0,0 @@
-- Deletes the immediate children of a directory: every name listed in the
-- directory's zset ("<dir>\0") has its entry key and its own listing key
-- removed. NOTE(review): this removes one level only; recursion, if needed,
-- is presumably handled by the caller — confirm.
-- KEYS[1]: full path of the directory
local fullpath = KEYS[1]

-- Strip a trailing slash so child paths are built as "<dir>/<name>".
if fullpath ~= "" and string.sub(fullpath, -1) == "/" then
	fullpath = string.sub(fullpath, 0, -2)
end

local files = redis.call("ZRANGE", fullpath .. "\0", "0", "-1")

for _, name in ipairs(files) do
	local file_path = fullpath .. "/" .. name
	redis.call("DEL", file_path, file_path .. "\0")
end

return 0

View File

@@ -1,25 +0,0 @@
package stored_procedure
import (
_ "embed"
"github.com/redis/go-redis/v9"
)
// init compiles the embedded Lua sources (populated by the go:embed
// directives below) into reusable redis.Script handles.
func init() {
	InsertEntryScript = redis.NewScript(insertEntry)
	DeleteEntryScript = redis.NewScript(deleteEntry)
	DeleteFolderChildrenScript = redis.NewScript(deleteFolderChildren)
}
// Embedded Lua sources and their compiled script handles (set in init).

//go:embed insert_entry.lua
var insertEntry string

// InsertEntryScript atomically stores an entry and registers it in its
// parent directory's listing.
var InsertEntryScript *redis.Script

//go:embed delete_entry.lua
var deleteEntry string

// DeleteEntryScript atomically removes an entry and its directory-listing
// membership.
var DeleteEntryScript *redis.Script

//go:embed delete_folder_children.lua
var deleteFolderChildren string

// DeleteFolderChildrenScript removes the immediate children of a directory.
var DeleteFolderChildrenScript *redis.Script

View File

@@ -1,27 +0,0 @@
-- Atomically stores an entry (with optional TTL) and, unless the parent is a
-- super-large directory, registers its name in the parent's listing zset.
-- KEYS[1]: full path of the entry
local full_path = KEYS[1]
-- KEYS[2]: directory-listing key of the entry's parent
local dir_list_key = KEYS[2]
-- ARGV[1]: serialized content of the entry
local entry = ARGV[1]
-- ARGV[2]: TTL of the entry in seconds (0 = no expiry)
local ttlSec = tonumber(ARGV[2])
-- ARGV[3]: "1" when the parent is a super-large directory
local isSuperLargeDirectory = ARGV[3] == "1"
-- ARGV[4]: zscore of the entry in the listing zset
local zscore = tonumber(ARGV[4])
-- ARGV[5]: name of the entry
local name = ARGV[5]

if ttlSec > 0 then
	redis.call("SET", full_path, entry, "EX", ttlSec)
else
	redis.call("SET", full_path, entry)
end

-- NX keeps an existing member's score when the entry is overwritten.
if not isSuperLargeDirectory and name ~= "" then
	redis.call("ZADD", dir_list_key, "NX", zscore, name)
end
return 0

View File

@@ -1,206 +0,0 @@
package redis_lua
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/seaweedfs/seaweedfs/weed/filer"
"github.com/seaweedfs/seaweedfs/weed/filer/redis_lua/stored_procedure"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
"github.com/seaweedfs/seaweedfs/weed/util"
)
const (
	// DIR_LIST_MARKER is appended to a directory path to form the key of
	// the zset that lists the directory's children ("<dir>\x00").
	DIR_LIST_MARKER = "\x00"
)
// UniversalRedisLuaStore implements the filer store operations shared by the
// single-node, cluster, and sentinel redis_lua store variants.
type UniversalRedisLuaStore struct {
	Client redis.UniversalClient
	// keyPrefix optionally namespaces every key written by this store.
	keyPrefix string
	// superLargeDirectoryHash holds directories whose child listings are
	// deliberately not tracked (too large for a single zset).
	superLargeDirectoryHash map[string]bool
}
// isSuperLargeDirectory reports whether dir was configured as super-large
// (its children are not tracked in a listing zset).
func (store *UniversalRedisLuaStore) isSuperLargeDirectory(dir string) (isSuperLargeDirectory bool) {
	_, found := store.superLargeDirectoryHash[dir]
	return found
}
// loadSuperLargeDirectories indexes the configured super-large directories
// into a set for O(1) membership checks.
func (store *UniversalRedisLuaStore) loadSuperLargeDirectories(superLargeDirectories []string) {
	hash := make(map[string]bool, len(superLargeDirectories))
	for _, dir := range superLargeDirectories {
		hash[dir] = true
	}
	store.superLargeDirectoryHash = hash
}
// getKey prepends the configured key prefix, if any, to a raw key.
func (store *UniversalRedisLuaStore) getKey(key string) string {
	if prefix := store.keyPrefix; prefix != "" {
		return prefix + key
	}
	return key
}
// BeginTransaction is a no-op: this store relies on per-operation Lua
// scripts for atomicity and does not support multi-operation transactions.
func (store *UniversalRedisLuaStore) BeginTransaction(ctx context.Context) (context.Context, error) {
	return ctx, nil
}
// CommitTransaction is a no-op (see BeginTransaction).
func (store *UniversalRedisLuaStore) CommitTransaction(ctx context.Context) error {
	return nil
}
// RollbackTransaction is a no-op (see BeginTransaction).
func (store *UniversalRedisLuaStore) RollbackTransaction(ctx context.Context) error {
	return nil
}
// InsertEntry serializes the entry (gzipping large chunk lists) and stores it
// atomically via the insert_entry Lua script, which also registers the name
// in the parent directory's listing zset.
func (store *UniversalRedisLuaStore) InsertEntry(ctx context.Context, entry *filer.Entry) (err error) {
	value, err := entry.EncodeAttributesAndChunks()
	if err != nil {
		return fmt.Errorf("encoding %s %+v: %v", entry.FullPath, entry.Attr, err)
	}

	// Only compress entries with many chunks; small payloads are not worth it.
	if len(entry.GetChunks()) > filer.CountEntryChunksForGzip {
		value = util.MaybeGzipData(value)
	}

	dir, name := entry.FullPath.DirAndName()

	// zscore is fixed at 0 so listing members sort lexically by name.
	err = stored_procedure.InsertEntryScript.Run(ctx, store.Client,
		[]string{store.getKey(string(entry.FullPath)), store.getKey(genDirectoryListKey(dir))},
		value, entry.TtlSec,
		store.isSuperLargeDirectory(dir), 0, name).Err()

	if err != nil {
		return fmt.Errorf("persisting %s : %v", entry.FullPath, err)
	}

	return nil
}
// UpdateEntry overwrites the stored entry; with Redis SET semantics an
// update is identical to an insert.
func (store *UniversalRedisLuaStore) UpdateEntry(ctx context.Context, entry *filer.Entry) (err error) {
	return store.InsertEntry(ctx, entry)
}
// FindEntry loads and decodes a single entry by full path; returns
// filer_pb.ErrNotFound when the key does not exist.
func (store *UniversalRedisLuaStore) FindEntry(ctx context.Context, fullpath util.FullPath) (entry *filer.Entry, err error) {
	data, err := store.Client.Get(ctx, store.getKey(string(fullpath))).Result()
	if err == redis.Nil {
		return nil, filer_pb.ErrNotFound
	}

	if err != nil {
		return nil, fmt.Errorf("get %s : %v", fullpath, err)
	}

	entry = &filer.Entry{
		FullPath: fullpath,
	}
	// Stored values may be gzipped (large chunk lists); decompress transparently.
	err = entry.DecodeAttributesAndChunks(util.MaybeDecompressData([]byte(data)))
	if err != nil {
		return entry, fmt.Errorf("decode %s : %v", entry.FullPath, err)
	}

	return entry, nil
}
// DeleteEntry atomically removes the entry, its own child-listing key, and
// its membership in the parent directory's listing via the delete_entry
// Lua script.
func (store *UniversalRedisLuaStore) DeleteEntry(ctx context.Context, fullpath util.FullPath) (err error) {
	dir, name := fullpath.DirAndName()

	err = stored_procedure.DeleteEntryScript.Run(ctx, store.Client,
		[]string{store.getKey(string(fullpath)), store.getKey(genDirectoryListKey(string(fullpath))), store.getKey(genDirectoryListKey(dir))},
		store.isSuperLargeDirectory(dir), name).Err()

	if err != nil {
		return fmt.Errorf("DeleteEntry %s : %v", fullpath, err)
	}

	return nil
}
// DeleteFolderChildren removes the immediate children of a directory via the
// delete_folder_children Lua script. Super-large directories keep no child
// listing, so there is nothing to delete for them.
func (store *UniversalRedisLuaStore) DeleteFolderChildren(ctx context.Context, fullpath util.FullPath) (err error) {
	if store.isSuperLargeDirectory(string(fullpath)) {
		return nil
	}

	err = stored_procedure.DeleteFolderChildrenScript.Run(ctx, store.Client,
		[]string{store.getKey(string(fullpath))}).Err()

	if err != nil {
		return fmt.Errorf("DeleteFolderChildren %s : %v", fullpath, err)
	}

	return nil
}
// ListDirectoryPrefixedEntries is not supported by this store; it always
// returns filer.ErrUnsupportedListDirectoryPrefixed.
// NOTE(review): callers presumably fall back to filtering over
// ListDirectoryEntries on this sentinel error — confirm.
func (store *UniversalRedisLuaStore) ListDirectoryPrefixedEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, prefix string, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
	return lastFileName, filer.ErrUnsupportedListDirectoryPrefixed
}
// ListDirectoryEntries streams directory children in lexical order, starting
// after (or at, when includeStartFile is true) startFileName, up to limit
// entries. Entries whose TTL has expired are lazily deleted and skipped;
// listing members with no backing entry are skipped.
func (store *UniversalRedisLuaStore) ListDirectoryEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
	dirListKey := store.getKey(genDirectoryListKey(string(dirPath)))

	// Build the ZRANGEBYLEX lower bound: "[" is inclusive, "(" exclusive,
	// "-" means from the very beginning.
	min := "-"
	if startFileName != "" {
		if includeStartFile {
			min = "[" + startFileName
		} else {
			min = "(" + startFileName
		}
	}

	members, err := store.Client.ZRangeByLex(ctx, dirListKey, &redis.ZRangeBy{
		Min:    min,
		Max:    "+",
		Offset: 0,
		Count:  limit,
	}).Result()
	if err != nil {
		return lastFileName, fmt.Errorf("list %s : %v", dirPath, err)
	}

	// fetch entry meta
	for _, fileName := range members {
		path := util.NewFullPath(string(dirPath), fileName)
		entry, err := store.FindEntry(ctx, path)
		// Track progress even for skipped names so callers can resume.
		lastFileName = fileName
		if err != nil {
			glog.V(0).InfofCtx(ctx, "list %s : %v", path, err)
			if err == filer_pb.ErrNotFound {
				// Listing member without backing data: skip it.
				continue
			}
		} else {
			if entry.TtlSec > 0 {
				if entry.Attr.Crtime.Add(time.Duration(entry.TtlSec) * time.Second).Before(time.Now()) {
					// Lazy expiry: drop the stale entry and keep listing.
					store.DeleteEntry(ctx, path)
					continue
				}
			}
			resEachEntryFunc, resEachEntryFuncErr := eachEntryFunc(entry)
			if resEachEntryFuncErr != nil {
				err = fmt.Errorf("failed to process eachEntryFunc: %w", resEachEntryFuncErr)
				break
			}
			if !resEachEntryFunc {
				// Callback requested early termination.
				break
			}
		}
	}

	return lastFileName, err
}
// genDirectoryListKey derives the children-listing zset key for a directory.
func genDirectoryListKey(dir string) (dirList string) {
	dirList = dir + DIR_LIST_MARKER
	return
}
// Shutdown closes the underlying Redis client and its connection pool.
func (store *UniversalRedisLuaStore) Shutdown() {
	store.Client.Close()
}

View File

@@ -1,42 +0,0 @@
package redis_lua
import (
"context"
"fmt"
"github.com/redis/go-redis/v9"
"github.com/seaweedfs/seaweedfs/weed/filer"
)
// KvPut stores a raw key/value pair with no expiry. Note that, unlike entry
// keys, KV keys are not namespaced with the store's key prefix.
func (store *UniversalRedisLuaStore) KvPut(ctx context.Context, key []byte, value []byte) (err error) {
	if err = store.Client.Set(ctx, string(key), value, 0).Err(); err != nil {
		return fmt.Errorf("kv put: %w", err)
	}
	return nil
}
// KvGet fetches a raw value; returns filer.ErrKvNotFound when the key is
// absent. On other errors the (empty) data and the error are returned
// together. KV keys are not namespaced with the store's key prefix.
func (store *UniversalRedisLuaStore) KvGet(ctx context.Context, key []byte) (value []byte, err error) {
	data, err := store.Client.Get(ctx, string(key)).Result()

	if err == redis.Nil {
		return nil, filer.ErrKvNotFound
	}

	return []byte(data), err
}
// KvDelete removes a raw key. Deleting a non-existent key is not an error.
// KV keys are not namespaced with the store's key prefix.
func (store *UniversalRedisLuaStore) KvDelete(ctx context.Context, key []byte) (err error) {
	if err = store.Client.Del(ctx, string(key)).Err(); err != nil {
		return fmt.Errorf("kv delete: %w", err)
	}
	return nil
}

View File

@@ -102,10 +102,6 @@ func PrepareStreamContent(masterClient wdclient.HasLookupFileIdFunction, jwtFunc
type VolumeServerJwtFunction func(fileId string) string
// noJwtFunc is the default VolumeServerJwtFunction: it supplies an empty JWT
// for every file id.
func noJwtFunc(_ string) string {
	return ""
}
type CacheInvalidator interface {
	InvalidateCache(fileId string)
}
@@ -276,33 +272,6 @@ func writeZero(w io.Writer, size int64) (err error) {
return return
} }
// ReadAll fetches the given chunks sequentially into buffer, resolving each
// chunk's current locations through the master client. The chunk views are
// bounded by len(buffer) via ViewFromChunks.
func ReadAll(ctx context.Context, buffer []byte, masterClient *wdclient.MasterClient, chunks []*filer_pb.FileChunk) error {
	lookupFileIdFn := func(ctx context.Context, fileId string) (targetUrls []string, err error) {
		return masterClient.LookupFileId(ctx, fileId)
	}

	chunkViews := ViewFromChunks(ctx, lookupFileIdFn, chunks, 0, int64(len(buffer)))

	idx := 0

	for x := chunkViews.Front(); x != nil; x = x.Next {
		chunkView := x.Value
		// Resolve current replica URLs for this chunk's file id.
		urlStrings, err := lookupFileIdFn(ctx, chunkView.FileId)
		if err != nil {
			glog.V(1).InfofCtx(ctx, "operation LookupFileId %s failed, err: %v", chunkView.FileId, err)
			return err
		}

		n, err := util_http.RetriedFetchChunkData(ctx, buffer[idx:idx+int(chunkView.ViewSize)], urlStrings, chunkView.CipherKey, chunkView.IsGzipped, chunkView.IsFullChunk(), chunkView.OffsetInChunk, chunkView.FileId)
		if err != nil {
			return err
		}
		idx += n
	}
	return nil
}
// ---------------- ChunkStreamReader ---------------------------------- // ---------------- ChunkStreamReader ----------------------------------
type ChunkStreamReader struct { type ChunkStreamReader struct {
head *Interval[*ChunkView] head *Interval[*ChunkView]

View File

@@ -1,281 +0,0 @@
package filer
import (
"bytes"
"context"
"errors"
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
"github.com/seaweedfs/seaweedfs/weed/wdclient"
)
// mockMasterClient implements HasLookupFileIdFunction and CacheInvalidator
// for tests: lookups are delegated to the injected lookupFunc, and every
// invalidated file id is recorded for later assertions.
type mockMasterClient struct {
	lookupFunc         func(ctx context.Context, fileId string) ([]string, error)
	invalidatedFileIds []string
}
// GetLookupFileIdFunction returns the injected lookup function.
func (m *mockMasterClient) GetLookupFileIdFunction() wdclient.LookupFileIdFunctionType {
	return m.lookupFunc
}
// InvalidateCache records the file id so tests can assert on invalidations.
func (m *mockMasterClient) InvalidateCache(fileId string) {
	m.invalidatedFileIds = append(m.invalidatedFileIds, fileId)
}
// TestUrlSlicesEqual exercises the urlSlicesEqual helper: equality must be
// order-independent but sensitive to length and duplicate counts.
func TestUrlSlicesEqual(t *testing.T) {
	tests := []struct {
		name     string
		a        []string
		b        []string
		expected bool
	}{
		{
			name:     "identical slices",
			a:        []string{"http://server1", "http://server2"},
			b:        []string{"http://server1", "http://server2"},
			expected: true,
		},
		{
			name:     "same URLs different order",
			a:        []string{"http://server1", "http://server2"},
			b:        []string{"http://server2", "http://server1"},
			expected: true,
		},
		{
			name:     "different URLs",
			a:        []string{"http://server1", "http://server2"},
			b:        []string{"http://server1", "http://server3"},
			expected: false,
		},
		{
			name:     "different lengths",
			a:        []string{"http://server1"},
			b:        []string{"http://server1", "http://server2"},
			expected: false,
		},
		{
			name:     "empty slices",
			a:        []string{},
			b:        []string{},
			expected: true,
		},
		{
			name:     "duplicates in both",
			a:        []string{"http://server1", "http://server1"},
			b:        []string{"http://server1", "http://server1"},
			expected: true,
		},
		{
			name:     "different duplicate counts",
			a:        []string{"http://server1", "http://server1"},
			b:        []string{"http://server1", "http://server2"},
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := urlSlicesEqual(tt.a, tt.b)
			if result != tt.expected {
				t.Errorf("urlSlicesEqual(%v, %v) = %v; want %v", tt.a, tt.b, result, tt.expected)
			}
		})
	}
}
// TestStreamContentWithCacheInvalidation verifies that preparing a stream
// performs exactly one initial lookup. The retry-after-invalidation path is
// not exercised here because it would require live HTTP servers.
func TestStreamContentWithCacheInvalidation(t *testing.T) {
	ctx := context.Background()
	fileId := "3,01234567890"
	callCount := 0

	oldUrls := []string{"http://failed-server:8080"}
	newUrls := []string{"http://working-server:8080"}

	mock := &mockMasterClient{
		lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
			callCount++
			if callCount == 1 {
				// First call returns failing server
				return oldUrls, nil
			}
			// After invalidation, return working server
			return newUrls, nil
		},
	}

	// Create a simple chunk
	chunks := []*filer_pb.FileChunk{
		{
			FileId: fileId,
			Offset: 0,
			Size:   10,
		},
	}

	streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
	if err != nil {
		t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
	}

	// Note: This test can't fully execute streamFn because it would require actual HTTP servers
	// However, we can verify the setup was created correctly
	if streamFn == nil {
		t.Fatal("Expected non-nil stream function")
	}

	// Verify the lookup was called
	if callCount != 1 {
		t.Errorf("Expected 1 lookup call, got %d", callCount)
	}
}
// TestCacheInvalidationInterface confirms mockMasterClient satisfies
// CacheInvalidator and that invalidated ids are recorded faithfully.
func TestCacheInvalidationInterface(t *testing.T) {
	mock := &mockMasterClient{
		lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
			return []string{"http://server:8080"}, nil
		},
	}

	fileId := "3,test123"

	// Simulate invalidation
	if invalidator, ok := interface{}(mock).(CacheInvalidator); ok {
		invalidator.InvalidateCache(fileId)
	} else {
		t.Fatal("mockMasterClient should implement CacheInvalidator")
	}

	// Check that the file ID was recorded as invalidated
	if len(mock.invalidatedFileIds) != 1 {
		t.Fatalf("Expected 1 invalidated file ID, got %d", len(mock.invalidatedFileIds))
	}
	if mock.invalidatedFileIds[0] != fileId {
		t.Errorf("Expected invalidated file ID %s, got %s", fileId, mock.invalidatedFileIds[0])
	}
}
// TestRetryLogicSkipsSameUrls verifies the urlSlicesEqual check that
// prevents infinite retries when a fresh lookup returns the same URL set.
func TestRetryLogicSkipsSameUrls(t *testing.T) {
	// This test verifies that the urlSlicesEqual check prevents infinite retries
	sameUrls := []string{"http://server1:8080", "http://server2:8080"}
	differentUrls := []string{"http://server3:8080", "http://server4:8080"}

	// Same URLs should return true (and thus skip retry)
	if !urlSlicesEqual(sameUrls, sameUrls) {
		t.Error("Expected same URLs to be equal")
	}

	// Different URLs should return false (and thus allow retry)
	if urlSlicesEqual(sameUrls, differentUrls) {
		t.Error("Expected different URLs to not be equal")
	}
}
// TestCanceledStreamSkipsCacheInvalidation verifies that when the stream's
// context is canceled before execution, the stream function returns
// context.Canceled without invalidating any cached lookups.
func TestCanceledStreamSkipsCacheInvalidation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	fileId := "3,canceled"

	mock := &mockMasterClient{
		lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
			return []string{"http://server:8080"}, nil
		},
	}

	chunks := []*filer_pb.FileChunk{
		{
			FileId: fileId,
			Offset: 0,
			Size:   10,
		},
	}

	streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
	if err != nil {
		t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
	}

	// Cancel before streaming so the write path observes the canceled context.
	cancel()

	err = streamFn(&bytes.Buffer{})
	if err != context.Canceled {
		t.Fatalf("expected context.Canceled, got %v", err)
	}

	if len(mock.invalidatedFileIds) != 0 {
		t.Fatalf("expected no cache invalidation on cancellation, got %v", mock.invalidatedFileIds)
	}
}
// TestPrepareStreamContentSkipsLookupWhenContextAlreadyCanceled verifies that
// preparation bails out with context.Canceled before attempting any lookup
// when the context is canceled up front. The backoff schedule is shortened
// and restored via t.Cleanup to keep the test fast and hermetic.
func TestPrepareStreamContentSkipsLookupWhenContextAlreadyCanceled(t *testing.T) {
	oldSchedule := getLookupFileIdBackoffSchedule
	getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond}
	t.Cleanup(func() {
		getLookupFileIdBackoffSchedule = oldSchedule
	})

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	lookupCalls := 0
	mock := &mockMasterClient{
		lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
			lookupCalls++
			return nil, errors.New("lookup should not run")
		},
	}

	chunks := []*filer_pb.FileChunk{
		{
			FileId: "3,precanceled",
			Offset: 0,
			Size:   10,
		},
	}

	_, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context.Canceled, got %v", err)
	}
	if lookupCalls != 0 {
		t.Fatalf("expected no lookup calls after cancellation, got %d", lookupCalls)
	}
}
// TestPrepareStreamContentStopsLookupRetriesAfterContextCancellation verifies
// that lookup retries stop as soon as the context is canceled mid-lookup:
// exactly one lookup runs even though the backoff schedule allows three.
func TestPrepareStreamContentStopsLookupRetriesAfterContextCancellation(t *testing.T) {
	oldSchedule := getLookupFileIdBackoffSchedule
	getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond}
	t.Cleanup(func() {
		getLookupFileIdBackoffSchedule = oldSchedule
	})

	ctx, cancel := context.WithCancel(context.Background())
	lookupCalls := 0
	mock := &mockMasterClient{
		lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
			lookupCalls++
			// Cancel from inside the first lookup to simulate mid-flight cancellation.
			cancel()
			return nil, context.Canceled
		},
	}

	chunks := []*filer_pb.FileChunk{
		{
			FileId: "3,cancel-during-lookup",
			Offset: 0,
			Size:   10,
		},
	}

	_, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context.Canceled, got %v", err)
	}
	if lookupCalls != 1 {
		t.Fatalf("expected lookup retries to stop after cancellation, got %d calls", lookupCalls)
	}
}

View File

@@ -37,11 +37,6 @@ func GenerateRandomString(length int, charset string) (string, error) {
return string(b), nil return string(b), nil
} }
// GenerateAccessKeyId generates a new access key ID: AccessKeyIdLength
// random characters drawn from CharsetUpper.
func GenerateAccessKeyId() (string, error) {
	return GenerateRandomString(AccessKeyIdLength, CharsetUpper)
}
// GenerateSecretAccessKey generates a new secret access key. // GenerateSecretAccessKey generates a new secret access key.
func GenerateSecretAccessKey() (string, error) { func GenerateSecretAccessKey() (string, error) {
return GenerateRandomString(SecretAccessKeyLength, Charset) return GenerateRandomString(SecretAccessKeyLength, Charset)
@@ -179,11 +174,3 @@ func MapToIdentitiesAction(action string) string {
return "" return ""
} }
} }
// MaskAccessKey masks an access key for logging, keeping only the first
// four characters visible. Keys of four characters or fewer are returned
// unchanged.
func MaskAccessKey(accessKeyId string) string {
	if len(accessKeyId) <= 4 {
		return accessKeyId
	}
	return accessKeyId[:4] + "***"
}

View File

@@ -1,164 +0,0 @@
package iam
import (
"testing"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/stretchr/testify/assert"
)
// TestHash verifies Hash is deterministic, emits 40-char SHA1 hex, and
// differs across inputs.
func TestHash(t *testing.T) {
	input := "test"
	result := Hash(&input)
	assert.NotEmpty(t, result)
	assert.Len(t, result, 40) // SHA1 hex is 40 chars

	// Same input should produce same hash
	result2 := Hash(&input)
	assert.Equal(t, result, result2)

	// Different input should produce different hash
	different := "different"
	result3 := Hash(&different)
	assert.NotEqual(t, result, result3)
}
// TestGenerateRandomString covers the happy path (length, randomness) and
// rejection of non-positive lengths and an empty charset.
func TestGenerateRandomString(t *testing.T) {
	// Valid generation
	result, err := GenerateRandomString(10, CharsetUpper)
	assert.NoError(t, err)
	assert.Len(t, result, 10)

	// Different calls should produce different results (with high probability)
	result2, err := GenerateRandomString(10, CharsetUpper)
	assert.NoError(t, err)
	assert.NotEqual(t, result, result2)

	// Invalid length
	_, err = GenerateRandomString(0, CharsetUpper)
	assert.Error(t, err)

	_, err = GenerateRandomString(-1, CharsetUpper)
	assert.Error(t, err)

	// Empty charset
	_, err = GenerateRandomString(10, "")
	assert.Error(t, err)
}
// TestGenerateAccessKeyId checks generated key IDs have the expected length.
func TestGenerateAccessKeyId(t *testing.T) {
	keyId, err := GenerateAccessKeyId()
	assert.NoError(t, err)
	assert.Len(t, keyId, AccessKeyIdLength)
}
// TestGenerateSecretAccessKey verifies a generated secret has the expected length.
func TestGenerateSecretAccessKey(t *testing.T) {
	secret, err := GenerateSecretAccessKey()
	assert.NoError(t, err)
	assert.Len(t, secret, SecretAccessKeyLength)
}
// TestGenerateSecretAccessKey_URLSafe verifies generated secrets contain only
// characters from Charset and never '/' or '+', which would break URL-based
// authentication flows.
func TestGenerateSecretAccessKey_URLSafe(t *testing.T) {
	// Repeat many times so an unsafe character would almost surely be caught.
	for attempt := 0; attempt < 100; attempt++ {
		secret, err := GenerateSecretAccessKey()
		assert.NoError(t, err)
		assert.NotContains(t, secret, "/", "Secret key should not contain /")
		assert.NotContains(t, secret, "+", "Secret key should not contain +")
		// Every character must come from the expected charset.
		for _, r := range secret {
			assert.Contains(t, Charset, string(r), "Secret key contains unexpected character: %c", r)
		}
	}
}
func TestStringSlicesEqual(t *testing.T) {
tests := []struct {
a []string
b []string
expected bool
}{
{[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
{[]string{"c", "b", "a"}, []string{"a", "b", "c"}, true}, // Order independent
{[]string{"a", "b"}, []string{"a", "b", "c"}, false},
{[]string{}, []string{}, true},
{nil, nil, true},
{[]string{"a"}, []string{"b"}, false},
}
for _, test := range tests {
result := StringSlicesEqual(test.a, test.b)
assert.Equal(t, test.expected, result)
}
}
// TestMapToStatementAction verifies mapping from IAM statement action names
// (coarse and fine-grained, with and without the "s3:" prefix) to the
// internal s3_constants action values.
func TestMapToStatementAction(t *testing.T) {
	cases := []struct {
		action string
		want   string
	}{
		{StatementActionAdmin, s3_constants.ACTION_ADMIN},
		{StatementActionWrite, s3_constants.ACTION_WRITE},
		{StatementActionRead, s3_constants.ACTION_READ},
		{StatementActionList, s3_constants.ACTION_LIST},
		{StatementActionDelete, s3_constants.ACTION_DELETE_BUCKET},
		// Test fine-grained S3 action mappings (Issue #7864)
		{"DeleteObject", s3_constants.ACTION_WRITE},
		{"s3:DeleteObject", s3_constants.ACTION_WRITE},
		{"PutObject", s3_constants.ACTION_WRITE},
		{"s3:PutObject", s3_constants.ACTION_WRITE},
		{"GetObject", s3_constants.ACTION_READ},
		{"s3:GetObject", s3_constants.ACTION_READ},
		{"ListBucket", s3_constants.ACTION_LIST},
		{"s3:ListBucket", s3_constants.ACTION_LIST},
		{"PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
		{"s3:PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
		{"GetObjectAcl", s3_constants.ACTION_READ_ACP},
		{"s3:GetObjectAcl", s3_constants.ACTION_READ_ACP},
		{"unknown", ""}, // unrecognized actions map to the empty string
	}
	for _, tc := range cases {
		got := MapToStatementAction(tc.action)
		assert.Equal(t, tc.want, got, "Failed for input: %s", tc.action)
	}
}
// TestMapToIdentitiesAction verifies the reverse mapping from internal
// s3_constants action values back to IAM statement action names.
func TestMapToIdentitiesAction(t *testing.T) {
	cases := []struct {
		action string
		want   string
	}{
		{s3_constants.ACTION_ADMIN, StatementActionAdmin},
		{s3_constants.ACTION_WRITE, StatementActionWrite},
		{s3_constants.ACTION_READ, StatementActionRead},
		{s3_constants.ACTION_LIST, StatementActionList},
		{s3_constants.ACTION_DELETE_BUCKET, StatementActionDelete},
		{"unknown", ""}, // unrecognized actions map to the empty string
	}
	for _, tc := range cases {
		got := MapToIdentitiesAction(tc.action)
		assert.Equal(t, tc.want, got)
	}
}
// TestMaskAccessKey verifies masking of access keys: longer keys keep only
// the first 4 characters, and short keys pass through unchanged.
func TestMaskAccessKey(t *testing.T) {
	cases := []struct {
		key  string
		want string
	}{
		{"AKIAIOSFODNN7EXAMPLE", "AKIA***"},
		{"AKIA", "AKIA"}, // exactly 4 chars: unmasked
		{"AKI", "AKI"},
		{"", ""},
	}
	for _, tc := range cases {
		got := MaskAccessKey(tc.key)
		assert.Equal(t, tc.want, got)
	}
}

View File

@@ -202,32 +202,6 @@ func (m *IAMManager) getFilerAddress() string {
return "" // Fallback to empty string if no provider is set return "" // Fallback to empty string if no provider is set
} }
// createRoleStore builds a RoleStore from the given configuration.
// A nil config, or an empty/"filer" store type, yields the generic cached
// filer-backed store; caching can be disabled via the "noCache" store option.
// An unrecognized store type is an error.
func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) {
	if config == nil {
		// No configuration at all: default to the generic cached filer store.
		return NewGenericCachedRoleStore(nil, nil)
	}

	switch config.StoreType {
	case "", "filer":
		// Honor an explicit request to bypass caching.
		if opts := config.StoreConfig; opts != nil {
			if noCache, ok := opts["noCache"].(bool); ok && noCache {
				return NewFilerRoleStore(opts, nil)
			}
		}
		// Cached filer store is the default for better performance.
		return NewGenericCachedRoleStore(config.StoreConfig, nil)
	case "cached-filer", "generic-cached":
		return NewGenericCachedRoleStore(config.StoreConfig, nil)
	case "memory":
		return NewMemoryRoleStore(), nil
	}
	return nil, fmt.Errorf("unsupported role store type: %s", config.StoreType)
}
// createRoleStoreWithProvider creates a role store with a filer address provider function // createRoleStoreWithProvider creates a role store with a filer address provider function
func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) { func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) {
if config == nil { if config == nil {

View File

@@ -388,157 +388,3 @@ type CachedFilerRoleStoreConfig struct {
ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s" ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s"
MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles
} }
// NewCachedFilerRoleStore creates a new cached filer-based role store.
// Cache behavior is tunable through the config map: "ttl" (per-role cache
// TTL, duration string), "listTtl" (role-list cache TTL), and
// "maxCacheSize" (maximum cached roles). Unparseable or missing values
// fall back to the defaults of 5m, 1m, and 1000 respectively.
func NewCachedFilerRoleStore(config map[string]interface{}) (*CachedFilerRoleStore, error) {
	// The cached store wraps a plain filer-backed store.
	backing, err := NewFilerRoleStore(config, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create filer role store: %w", err)
	}

	// Defaults: 5m role TTL, 1m list TTL, at most 1000 cached roles.
	roleTTL := 5 * time.Minute
	rolesListTTL := 1 * time.Minute
	maxEntries := 1000
	if config != nil {
		if raw, ok := config["ttl"].(string); ok && raw != "" {
			if d, parseErr := time.ParseDuration(raw); parseErr == nil {
				roleTTL = d
			}
		}
		if raw, ok := config["listTtl"].(string); ok && raw != "" {
			if d, parseErr := time.ParseDuration(raw); parseErr == nil {
				rolesListTTL = d
			}
		}
		if size, ok := config["maxCacheSize"].(int); ok && size > 0 {
			maxEntries = size
		}
	}

	// Prune in batches of 1/8th of the cache size, with a floor of 100.
	itemsToPrune := int64(maxEntries) >> 3
	if itemsToPrune <= 0 {
		itemsToPrune = 100
	}

	store := &CachedFilerRoleStore{
		filerStore: backing,
		cache:      ccache.New(ccache.Configure().MaxSize(int64(maxEntries)).ItemsToPrune(uint32(itemsToPrune))),
		listCache:  ccache.New(ccache.Configure().MaxSize(100).ItemsToPrune(10)), // smaller cache for lists
		ttl:        roleTTL,
		listTTL:    rolesListTTL,
	}
	glog.V(2).Infof("Initialized CachedFilerRoleStore with TTL %v, List TTL %v, Max Cache Size %d",
		roleTTL, rolesListTTL, maxEntries)
	return store, nil
}
// StoreRole persists a role definition to the filer and invalidates the
// corresponding cache entries (write-through with invalidation).
func (c *CachedFilerRoleStore) StoreRole(ctx context.Context, filerAddress string, roleName string, role *RoleDefinition) error {
	if err := c.filerStore.StoreRole(ctx, filerAddress, roleName, role); err != nil {
		return err
	}
	// A successful write invalidates both the per-role entry and the list cache.
	c.cache.Delete(roleName)
	c.listCache.Clear()
	glog.V(3).Infof("Stored and invalidated cache for role %s", roleName)
	return nil
}
// GetRole returns a role definition, serving from the cache when possible
// and falling back to the filer on a miss. Cached hits do not extend the
// entry's TTL, and copies are handed out so callers cannot mutate the
// cached value.
func (c *CachedFilerRoleStore) GetRole(ctx context.Context, filerAddress string, roleName string) (*RoleDefinition, error) {
	// Fast path: cache hit (TTL is deliberately not extended).
	if cached := c.cache.Get(roleName); cached != nil {
		glog.V(4).Infof("Cache hit for role %s", roleName)
		return copyRoleDefinition(cached.Value().(*RoleDefinition)), nil
	}

	// Slow path: fetch from the filer.
	glog.V(4).Infof("Cache miss for role %s, fetching from filer", roleName)
	role, err := c.filerStore.GetRole(ctx, filerAddress, roleName)
	if err != nil {
		return nil, err
	}

	// Cache a defensive copy so later mutations of the returned value
	// cannot corrupt the cached entry.
	c.cache.Set(roleName, copyRoleDefinition(role), c.ttl)
	glog.V(3).Infof("Cached role %s with TTL %v", roleName, c.ttl)
	return role, nil
}
// ListRoles returns all role names, caching the listing under a single
// fixed key. Both the cached value and the returned slice are copies, so
// callers and the cache cannot alias each other.
func (c *CachedFilerRoleStore) ListRoles(ctx context.Context, filerAddress string) ([]string, error) {
	// The whole listing is cached under one constant key.
	const listCacheKey = "role_list"

	// Fast path: cache hit (TTL is deliberately not extended).
	if cached := c.listCache.Get(listCacheKey); cached != nil {
		names := cached.Value().([]string)
		glog.V(4).Infof("List cache hit, returning %d roles", len(names))
		return append([]string(nil), names...), nil // hand out a copy
	}

	// Slow path: fetch from the filer.
	glog.V(4).Infof("List cache miss, fetching from filer")
	names, err := c.filerStore.ListRoles(ctx, filerAddress)
	if err != nil {
		return nil, err
	}

	// Cache a copy so callers mutating the returned slice cannot affect it.
	c.listCache.Set(listCacheKey, append([]string(nil), names...), c.listTTL)
	glog.V(3).Infof("Cached role list with %d entries, TTL %v", len(names), c.listTTL)
	return names, nil
}
// DeleteRole removes a role definition from the filer and invalidates the
// corresponding cache entries.
func (c *CachedFilerRoleStore) DeleteRole(ctx context.Context, filerAddress string, roleName string) error {
	if err := c.filerStore.DeleteRole(ctx, filerAddress, roleName); err != nil {
		return err
	}
	// A successful delete invalidates both the per-role entry and the list cache.
	c.cache.Delete(roleName)
	c.listCache.Clear()
	glog.V(3).Infof("Deleted and invalidated cache for role %s", roleName)
	return nil
}
// ClearCache drops every cached role and the cached role list. Intended for
// tests and manual cache invalidation.
func (c *CachedFilerRoleStore) ClearCache() {
	c.cache.Clear()
	c.listCache.Clear()
	glog.V(2).Infof("Cleared all role cache entries")
}
// GetCacheStats reports item counts and configured TTLs for both internal
// caches, keyed as "roleCache" and "listCache".
func (c *CachedFilerRoleStore) GetCacheStats() map[string]interface{} {
	roleStats := map[string]interface{}{
		"size": c.cache.ItemCount(),
		"ttl":  c.ttl.String(),
	}
	listStats := map[string]interface{}{
		"size": c.listCache.ItemCount(),
		"ttl":  c.listTTL.String(),
	}
	return map[string]interface{}{
		"roleCache": roleStats,
		"listCache": listStats,
	}
}

View File

@@ -1,687 +0,0 @@
package policy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConditionSetOperators exercises the policy engine's condition set
// operators — the ForAnyValue:/ForAllValues: variants of String, Numeric,
// Bool, Date, and IpAddress operators — against both trust policies
// (EvaluateTrustPolicy) and identity policies (AddPolicy + Evaluate),
// including vacuous-truth semantics for empty context values.
//
// NOTE(review): every subtest shares the single engine created by
// setupTestPolicyEngine, and policies added via AddPolicy accumulate in it
// across subtests — confirm there is no cross-subtest interference before
// reordering or parallelizing these subtests.
func TestConditionSetOperators(t *testing.T) {
	engine := setupTestPolicyEngine(t)
	t.Run("ForAnyValue:StringEquals", func(t *testing.T) {
		trustPolicy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowOIDC",
					Effect: "Allow",
					Action: []string{"sts:AssumeRoleWithWebIdentity"},
					Condition: map[string]map[string]interface{}{
						"ForAnyValue:StringEquals": {
							"oidc:roles": []string{"Dev.SeaweedFS.TestBucket.ReadWrite", "Dev.SeaweedFS.Admin"},
						},
					},
				},
			},
		}
		// Match: Admin is in the requested roles
		evalCtxMatch := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"Dev.SeaweedFS.Admin", "OtherRole"},
			},
		}
		resultMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxMatch)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultMatch.Effect)
		// No Match
		evalCtxNoMatch := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"OtherRole1", "OtherRole2"},
			},
		}
		resultNoMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxNoMatch)
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultNoMatch.Effect)
		// No Match: Empty context for ForAnyValue (should deny)
		evalCtxEmpty := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{},
			},
		}
		resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxEmpty)
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultEmpty.Effect, "ForAnyValue should deny when context is empty")
	})
	t.Run("ForAllValues:StringEquals", func(t *testing.T) {
		trustPolicyAll := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowOIDCAll",
					Effect: "Allow",
					Action: []string{"sts:AssumeRoleWithWebIdentity"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:StringEquals": {
							"oidc:roles": []string{"RoleA", "RoleB", "RoleC"},
						},
					},
				},
			},
		}
		// Match: All requested roles ARE in the allowed set
		evalCtxAllMatch := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"RoleA", "RoleB"},
			},
		}
		resultAllMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllMatch)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultAllMatch.Effect)
		// Fail: RoleD is NOT in the allowed set
		evalCtxAllFail := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"RoleA", "RoleD"},
			},
		}
		resultAllFail, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllFail)
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultAllFail.Effect)
		// Vacuously true: Request has NO roles
		evalCtxEmpty := &EvaluationContext{
			Principal: "web-identity-user",
			Action:    "sts:AssumeRoleWithWebIdentity",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{},
			},
		}
		resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxEmpty)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultEmpty.Effect)
	})
	t.Run("ForAllValues:NumericEqualsVacuouslyTrue", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowNumericAll",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:NumericEquals": {
							"aws:MultiFactorAuthAge": []string{"3600", "7200"},
						},
					},
				},
			},
		}
		// Vacuously true: Request has NO MFA age info
		evalCtxEmpty := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:MultiFactorAuthAge": []string{},
			},
		}
		resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when numeric context is empty for ForAllValues")
	})
	t.Run("ForAllValues:BoolVacuouslyTrue", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowBoolAll",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:Bool": {
							"aws:SecureTransport": "true",
						},
					},
				},
			},
		}
		// Vacuously true
		evalCtxEmpty := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:SecureTransport": []interface{}{},
			},
		}
		resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when bool context is empty for ForAllValues")
	})
	t.Run("ForAllValues:DateVacuouslyTrue", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowDateAll",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:DateGreaterThan": {
							"aws:CurrentTime": "2020-01-01T00:00:00Z",
						},
					},
				},
			},
		}
		// Vacuously true
		evalCtxEmpty := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:CurrentTime": []interface{}{},
			},
		}
		resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when date context is empty for ForAllValues")
	})
	t.Run("ForAllValues:DateWithLabelsAsStrings", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowDateStrings",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:DateGreaterThan": {
							"aws:CurrentTime": "2020-01-01T00:00:00Z",
						},
					},
				},
			},
		}
		evalCtx := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:CurrentTime": []string{"2021-01-01T00:00:00Z", "2022-01-01T00:00:00Z"},
			},
		}
		result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, result.Effect, "Should allow when date context is a slice of strings")
	})
	t.Run("ForAllValues:BoolWithLabelsAsStrings", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "AllowBoolStrings",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:Bool": {
							"aws:SecureTransport": "true",
						},
					},
				},
			},
		}
		evalCtx := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:SecureTransport": []string{"true", "true"},
			},
		}
		result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, result.Effect, "Should allow when bool context is a slice of strings")
	})
	t.Run("StringEqualsIgnoreCaseWithVariable", func(t *testing.T) {
		policyDoc := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowVar",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"arn:aws:s3:::bucket/*"},
					Condition: map[string]map[string]interface{}{
						"StringEqualsIgnoreCase": {
							"s3:prefix": "${aws:username}/",
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "var-policy", policyDoc)
		require.NoError(t, err)
		evalCtx := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/ALICE/file.txt",
			RequestContext: map[string]interface{}{
				"s3:prefix":    "ALICE/",
				"aws:username": "alice",
			},
		}
		result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"var-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, result.Effect, "Should allow when variable expands and matches case-insensitively")
	})
	t.Run("StringLike:CaseSensitivity", func(t *testing.T) {
		policyDoc := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowCaseSensitiveLike",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"arn:aws:s3:::bucket/*"},
					Condition: map[string]map[string]interface{}{
						"StringLike": {
							"s3:prefix": "Project/*",
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "like-policy", policyDoc)
		require.NoError(t, err)
		// Match: Case sensitive match
		evalCtxMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/Project/file.txt",
			RequestContext: map[string]interface{}{
				"s3:prefix": "Project/data",
			},
		}
		resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"like-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultMatch.Effect, "Should allow when case matches exactly")
		// Fail: Case insensitive match (should fail for StringLike)
		evalCtxFail := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/project/file.txt",
			RequestContext: map[string]interface{}{
				"s3:prefix": "project/data", // lowercase 'p'
			},
		}
		resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"like-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when case does not match for StringLike")
	})
	t.Run("NumericNotEquals:Logic", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "DenySpecificAges",
					Effect:   "Allow",
					Action:   []string{"sts:AssumeRole"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:NumericNotEquals": {
							"aws:MultiFactorAuthAge": []string{"3600", "7200"},
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "numeric-not-equals-policy", policy)
		require.NoError(t, err)
		// Fail: One age matches an excluded value (3600)
		evalCtxFail := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:MultiFactorAuthAge": []string{"3600", "1800"},
			},
		}
		resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"numeric-not-equals-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one age matches an excluded value")
		// Pass: No age matches any excluded value
		evalCtxPass := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:MultiFactorAuthAge": []string{"1800", "900"},
			},
		}
		resultPass, err := engine.Evaluate(context.Background(), "", evalCtxPass, []string{"numeric-not-equals-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultPass.Effect, "Should allow when no age matches excluded values")
	})
	t.Run("DateNotEquals:Logic", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "DenySpecificTimes",
					Effect:   "Allow",
					Action:   []string{"sts:AssumeRole"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:DateNotEquals": {
							"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"},
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "date-not-equals-policy", policy)
		require.NoError(t, err)
		// Fail: One time matches an excluded value
		evalCtxFail := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-03T00:00:00Z"},
			},
		}
		resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"date-not-equals-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one date matches an excluded value")
	})
	t.Run("IpAddress:SetOperators", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowSpecificIPs",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:IpAddress": {
							"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.1"},
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "ip-set-policy", policy)
		require.NoError(t, err)
		// Match: All source IPs are in allowed ranges
		evalCtxMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SourceIp": []string{"192.168.1.10", "10.0.0.1"},
			},
		}
		resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-set-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultMatch.Effect)
		// Fail: One source IP is NOT in allowed ranges
		evalCtxFail := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
			},
		}
		resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"ip-set-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultFail.Effect)
		// ForAnyValue: IPAddress
		policyAny := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowAnySpecificIPs",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"ForAnyValue:IpAddress": {
							"aws:SourceIp": []string{"192.168.1.0/24"},
						},
					},
				},
			},
		}
		err = engine.AddPolicy("", "ip-any-policy", policyAny)
		require.NoError(t, err)
		evalCtxAnyMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
			},
		}
		resultAnyMatch, err := engine.Evaluate(context.Background(), "", evalCtxAnyMatch, []string{"ip-any-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultAnyMatch.Effect)
	})
	t.Run("IpAddress:SingleStringValue", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowSingleIP",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"IpAddress": {
							"aws:SourceIp": "192.168.1.1",
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "ip-single-policy", policy)
		require.NoError(t, err)
		evalCtxMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SourceIp": "192.168.1.1",
			},
		}
		resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-single-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultMatch.Effect)
		evalCtxNoMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SourceIp": "10.0.0.1",
			},
		}
		resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-single-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultNoMatch.Effect)
	})
	t.Run("Bool:StringSlicePolicyValues", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowWithBoolStrings",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"Bool": {
							"aws:SecureTransport": []string{"true", "false"},
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "bool-string-slice-policy", policy)
		require.NoError(t, err)
		evalCtx := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"aws:SecureTransport": "true",
			},
		}
		result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"bool-string-slice-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, result.Effect)
	})
	t.Run("StringEqualsIgnoreCase:StringSlicePolicyValues", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowWithIgnoreCaseStrings",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"StringEqualsIgnoreCase": {
							"s3:x-amz-server-side-encryption": []string{"AES256", "aws:kms"},
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "string-ignorecase-slice-policy", policy)
		require.NoError(t, err)
		evalCtx := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"s3:x-amz-server-side-encryption": "aes256",
			},
		}
		result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"string-ignorecase-slice-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, result.Effect)
	})
	t.Run("IpAddress:CustomContextKey", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:      "AllowCustomIPKey",
					Effect:   "Allow",
					Action:   []string{"s3:GetObject"},
					Resource: []string{"*"},
					Condition: map[string]map[string]interface{}{
						"IpAddress": {
							"custom:VpcIp": "10.0.0.0/16",
						},
					},
				},
			},
		}
		err := engine.AddPolicy("", "ip-custom-key-policy", policy)
		require.NoError(t, err)
		evalCtxMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"custom:VpcIp": "10.0.5.1",
			},
		}
		resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-custom-key-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultMatch.Effect)
		evalCtxNoMatch := &EvaluationContext{
			Principal: "user",
			Action:    "s3:GetObject",
			Resource:  "arn:aws:s3:::bucket/file.txt",
			RequestContext: map[string]interface{}{
				"custom:VpcIp": "192.168.1.1",
			},
		}
		resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-custom-key-policy"})
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultNoMatch.Effect)
	})
}

View File

@@ -1,101 +0,0 @@
package policy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNegationSetOperators exercises the negated string set operators
// (ForAllValues:StringNotEquals and ForAnyValue:StringNotEquals) in trust
// policies evaluated via EvaluateTrustPolicy.
//
// NOTE(review): both subtests share the engine from setupTestPolicyEngine;
// they only evaluate ad-hoc trust policies and do not call AddPolicy, so no
// state appears to accumulate — confirm before parallelizing.
func TestNegationSetOperators(t *testing.T) {
	engine := setupTestPolicyEngine(t)
	t.Run("ForAllValues:StringNotEquals", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "DenyAdmin",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAllValues:StringNotEquals": {
							"oidc:roles": []string{"Admin"},
						},
					},
				},
			},
		}
		// All roles are NOT "Admin" -> Should Allow
		evalCtxAllow := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"User", "Developer"},
			},
		}
		resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when ALL roles satisfy StringNotEquals Admin")
		// One role is "Admin" -> Should Deny
		evalCtxDeny := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"Admin", "User"},
			},
		}
		resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when one role is Admin and fails StringNotEquals")
	})
	t.Run("ForAnyValue:StringNotEquals", func(t *testing.T) {
		policy := &PolicyDocument{
			Version: "2012-10-17",
			Statement: []Statement{
				{
					Sid:    "Requirement",
					Effect: "Allow",
					Action: []string{"sts:AssumeRole"},
					Condition: map[string]map[string]interface{}{
						"ForAnyValue:StringNotEquals": {
							"oidc:roles": []string{"Prohibited"},
						},
					},
				},
			},
		}
		// At least one role is NOT prohibited -> Should Allow
		evalCtxAllow := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"Prohibited", "Allowed"},
			},
		}
		resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
		require.NoError(t, err)
		assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when at least one role is NOT Prohibited")
		// All roles are Prohibited -> Should Deny
		evalCtxDeny := &EvaluationContext{
			Principal: "user",
			Action:    "sts:AssumeRole",
			Resource:  "arn:aws:iam::role/test-role",
			RequestContext: map[string]interface{}{
				"oidc:roles": []string{"Prohibited", "Prohibited"},
			},
		}
		resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
		require.NoError(t, err)
		assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when ALL roles are Prohibited")
	})
}

View File

@@ -1155,11 +1155,6 @@ func ValidatePolicyDocumentWithType(policy *PolicyDocument, policyType string) e
return nil return nil
} }
// validateStatement validates a single statement using the "resource"
// policy type; kept for backward compatibility with older callers.
func validateStatement(statement *Statement) error {
	const defaultPolicyType = "resource"
	return validateStatementWithType(statement, defaultPolicyType)
}
// validateStatementWithType validates a single statement based on policy type // validateStatementWithType validates a single statement based on policy type
func validateStatementWithType(statement *Statement, policyType string) error { func validateStatementWithType(statement *Statement, policyType string) error {
if statement.Effect != "Allow" && statement.Effect != "Deny" { if statement.Effect != "Allow" && statement.Effect != "Deny" {
@@ -1198,29 +1193,6 @@ func validateStatementWithType(statement *Statement, policyType string) error {
return nil return nil
} }
// matchResource reports whether a policy resource pattern matches the
// requested resource. Exact strings and trailing-"*" prefix patterns are
// handled directly for backward compatibility; any other pattern is
// delegated to filepath.Match for *, ?, and [...] support.
func matchResource(pattern, resource string) bool {
	switch {
	case pattern == resource:
		return true
	case strings.HasSuffix(pattern, "*"):
		// Trailing wildcard: plain prefix match (legacy behavior; unlike
		// filepath.Match, this also matches across "/" separators).
		return strings.HasPrefix(resource, strings.TrimSuffix(pattern, "*"))
	}
	matched, err := filepath.Match(pattern, resource)
	if err != nil {
		// Malformed pattern: degrade to exact comparison, which is already
		// known to be false at this point.
		return false
	}
	return matched
}
// awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support // awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support
func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool { func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool {
// Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username}) // Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username})
@@ -1274,16 +1246,6 @@ func expandPolicyVariables(pattern string, evalCtx *EvaluationContext) string {
return result return result
} }
// getContextValue returns the string stored under key in the evaluation
// context's request data, or defaultValue when the key is absent or the
// stored value is not a string.
func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) string {
	raw, found := evalCtx.RequestContext[key]
	if !found {
		return defaultValue
	}
	s, ok := raw.(string)
	if !ok {
		// Non-string values fall back to the default as well.
		return defaultValue
	}
	return s
}
// AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM // AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
func AwsWildcardMatch(pattern, value string) bool { func AwsWildcardMatch(pattern, value string) bool {
// Create regex pattern key for caching // Create regex pattern key for caching
@@ -1322,29 +1284,6 @@ func AwsWildcardMatch(pattern, value string) bool {
return regex.MatchString(value) return regex.MatchString(value)
} }
// matchAction reports whether a policy action pattern matches a requested
// action. Exact matches and trailing-"*" prefixes are checked first for
// backward compatibility; everything else is handed to filepath.Match for
// full *, ?, and [...] wildcard support.
func matchAction(pattern, action string) bool {
	if pattern == action {
		return true
	}
	if !strings.HasSuffix(pattern, "*") {
		// Complex patterns (?, [...], embedded *) go through filepath.Match;
		// a malformed pattern degrades to an exact comparison, which is
		// already known to be false here.
		if ok, err := filepath.Match(pattern, action); err == nil {
			return ok
		}
		return false
	}
	// Trailing "*": legacy prefix-match semantics.
	return strings.HasPrefix(action, strings.TrimSuffix(pattern, "*"))
}
// evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity // evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity
func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool { func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool {
for key, expectedValues := range block { for key, expectedValues := range block {

View File

@@ -1,421 +0,0 @@
package policy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPrincipalMatching tests the matchesPrincipal method
func TestPrincipalMatching(t *testing.T) {
engine := setupTestPolicyEngine(t)
tests := []struct {
name string
principal interface{}
evalCtx *EvaluationContext
want bool
}{
{
name: "plain wildcard principal",
principal: "*",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: true,
},
{
name: "structured wildcard federated principal",
principal: map[string]interface{}{
"Federated": "*",
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: true,
},
{
name: "wildcard in array",
principal: map[string]interface{}{
"Federated": []interface{}{"specific-provider", "*"},
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: true,
},
{
name: "specific federated provider match",
principal: map[string]interface{}{
"Federated": "https://example.com/oidc",
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://example.com/oidc",
},
},
want: true,
},
{
name: "specific federated provider no match",
principal: map[string]interface{}{
"Federated": "https://example.com/oidc",
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://other.com/oidc",
},
},
want: false,
},
{
name: "array with specific provider match",
principal: map[string]interface{}{
"Federated": []string{"https://provider1.com", "https://provider2.com"},
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://provider2.com",
},
},
want: true,
},
{
name: "AWS principal match",
principal: map[string]interface{}{
"AWS": "arn:aws:iam::123456789012:user/alice",
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:PrincipalArn": "arn:aws:iam::123456789012:user/alice",
},
},
want: true,
},
{
name: "Service principal match",
principal: map[string]interface{}{
"Service": "s3.amazonaws.com",
},
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:PrincipalServiceName": "s3.amazonaws.com",
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := engine.matchesPrincipal(tt.principal, tt.evalCtx)
assert.Equal(t, tt.want, result, "Principal matching failed for: %s", tt.name)
})
}
}
// TestEvaluatePrincipalValue tests the evaluatePrincipalValue method
func TestEvaluatePrincipalValue(t *testing.T) {
engine := setupTestPolicyEngine(t)
tests := []struct {
name string
principalValue interface{}
contextKey string
evalCtx *EvaluationContext
want bool
}{
{
name: "wildcard string",
principalValue: "*",
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: true,
},
{
name: "specific string match",
principalValue: "https://example.com",
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://example.com",
},
},
want: true,
},
{
name: "specific string no match",
principalValue: "https://example.com",
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://other.com",
},
},
want: false,
},
{
name: "wildcard in array",
principalValue: []interface{}{"provider1", "*"},
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: true,
},
{
name: "array match",
principalValue: []string{"provider1", "provider2", "provider3"},
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "provider2",
},
},
want: true,
},
{
name: "array no match",
principalValue: []string{"provider1", "provider2"},
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "provider3",
},
},
want: false,
},
{
name: "missing context key",
principalValue: "specific-value",
contextKey: "aws:FederatedProvider",
evalCtx: &EvaluationContext{
RequestContext: map[string]interface{}{},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := engine.evaluatePrincipalValue(tt.principalValue, tt.evalCtx, tt.contextKey)
assert.Equal(t, tt.want, result, "Principal value evaluation failed for: %s", tt.name)
})
}
}
// TestTrustPolicyEvaluation tests the EvaluateTrustPolicy method
func TestTrustPolicyEvaluation(t *testing.T) {
engine := setupTestPolicyEngine(t)
tests := []struct {
name string
trustPolicy *PolicyDocument
evalCtx *EvaluationContext
wantEffect Effect
wantErr bool
}{
{
name: "wildcard federated principal allows any provider",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: map[string]interface{}{
"Federated": "*",
},
Action: []string{"sts:AssumeRoleWithWebIdentity"},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://any-provider.com",
},
},
wantEffect: EffectAllow,
wantErr: false,
},
{
name: "specific federated principal matches",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: map[string]interface{}{
"Federated": "https://example.com/oidc",
},
Action: []string{"sts:AssumeRoleWithWebIdentity"},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://example.com/oidc",
},
},
wantEffect: EffectAllow,
wantErr: false,
},
{
name: "specific federated principal does not match",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: map[string]interface{}{
"Federated": "https://example.com/oidc",
},
Action: []string{"sts:AssumeRoleWithWebIdentity"},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://other.com/oidc",
},
},
wantEffect: EffectDeny,
wantErr: false,
},
{
name: "plain wildcard principal",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: "*",
Action: []string{"sts:AssumeRoleWithWebIdentity"},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://any-provider.com",
},
},
wantEffect: EffectAllow,
wantErr: false,
},
{
name: "trust policy with conditions",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: map[string]interface{}{
"Federated": "*",
},
Action: []string{"sts:AssumeRoleWithWebIdentity"},
Condition: map[string]map[string]interface{}{
"StringEquals": {
"oidc:aud": "my-app-id",
},
},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://provider.com",
"oidc:aud": "my-app-id",
},
},
wantEffect: EffectAllow,
wantErr: false,
},
{
name: "trust policy condition not met",
trustPolicy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Allow",
Principal: map[string]interface{}{
"Federated": "*",
},
Action: []string{"sts:AssumeRoleWithWebIdentity"},
Condition: map[string]map[string]interface{}{
"StringEquals": {
"oidc:aud": "my-app-id",
},
},
},
},
},
evalCtx: &EvaluationContext{
Action: "sts:AssumeRoleWithWebIdentity",
RequestContext: map[string]interface{}{
"aws:FederatedProvider": "https://provider.com",
"oidc:aud": "wrong-app-id",
},
},
wantEffect: EffectDeny,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := engine.EvaluateTrustPolicy(context.Background(), tt.trustPolicy, tt.evalCtx)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.wantEffect, result.Effect, "Trust policy evaluation failed for: %s", tt.name)
}
})
}
}
// TestGetPrincipalContextKey tests the context key mapping for each known
// principal type plus the generic fallback for custom types.
func TestGetPrincipalContextKey(t *testing.T) {
	cases := []struct {
		name          string
		principalType string
		expected      string
	}{
		{name: "Federated principal", principalType: "Federated", expected: "aws:FederatedProvider"},
		{name: "AWS principal", principalType: "AWS", expected: "aws:PrincipalArn"},
		{name: "Service principal", principalType: "Service", expected: "aws:PrincipalServiceName"},
		{name: "Custom principal type", principalType: "CustomType", expected: "aws:PrincipalCustomType"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := getPrincipalContextKey(tc.principalType)
			assert.Equal(t, tc.expected, got, "Context key mapping failed for: %s", tc.name)
		})
	}
}

View File

@@ -1,426 +0,0 @@
package policy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPolicyEngineInitialization tests policy engine initialization with a
// valid config, an invalid default effect, and a nil config.
func TestPolicyEngineInitialization(t *testing.T) {
	cases := []struct {
		name    string
		config  *PolicyEngineConfig
		wantErr bool
	}{
		{name: "valid config", config: &PolicyEngineConfig{DefaultEffect: "Deny", StoreType: "memory"}, wantErr: false},
		{name: "invalid default effect", config: &PolicyEngineConfig{DefaultEffect: "Invalid", StoreType: "memory"}, wantErr: true},
		{name: "nil config", config: nil, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			engine := NewPolicyEngine()
			err := engine.Initialize(tc.config)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.True(t, engine.IsInitialized())
		})
	}
}
// TestPolicyDocumentValidation tests policy document structure validation
func TestPolicyDocumentValidation(t *testing.T) {
tests := []struct {
name string
policy *PolicyDocument
wantErr bool
errorMsg string
}{
{
name: "valid policy document",
policy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Sid: "AllowS3Read",
Effect: "Allow",
Action: []string{"s3:GetObject", "s3:ListBucket"},
Resource: []string{"arn:aws:s3:::mybucket/*"},
},
},
},
wantErr: false,
},
{
name: "missing version",
policy: &PolicyDocument{
Statement: []Statement{
{
Effect: "Allow",
Action: []string{"s3:GetObject"},
Resource: []string{"arn:aws:s3:::mybucket/*"},
},
},
},
wantErr: true,
errorMsg: "version is required",
},
{
name: "empty statements",
policy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{},
},
wantErr: true,
errorMsg: "at least one statement is required",
},
{
name: "invalid effect",
policy: &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Effect: "Maybe",
Action: []string{"s3:GetObject"},
Resource: []string{"arn:aws:s3:::mybucket/*"},
},
},
},
wantErr: true,
errorMsg: "invalid effect",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePolicyDocument(tt.policy)
if tt.wantErr {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
// TestPolicyEvaluation tests policy evaluation logic
func TestPolicyEvaluation(t *testing.T) {
engine := setupTestPolicyEngine(t)
// Add test policies
readPolicy := &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Sid: "AllowS3Read",
Effect: "Allow",
Action: []string{"s3:GetObject", "s3:ListBucket"},
Resource: []string{
"arn:aws:s3:::public-bucket/*", // For object operations
"arn:aws:s3:::public-bucket", // For bucket operations
},
},
},
}
err := engine.AddPolicy("", "read-policy", readPolicy)
require.NoError(t, err)
denyPolicy := &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Sid: "DenyS3Delete",
Effect: "Deny",
Action: []string{"s3:DeleteObject"},
Resource: []string{"arn:aws:s3:::*"},
},
},
}
err = engine.AddPolicy("", "deny-policy", denyPolicy)
require.NoError(t, err)
tests := []struct {
name string
context *EvaluationContext
policies []string
want Effect
}{
{
name: "allow read access",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:GetObject",
Resource: "arn:aws:s3:::public-bucket/file.txt",
RequestContext: map[string]interface{}{
"aws:SourceIp": "192.168.1.100",
},
},
policies: []string{"read-policy"},
want: EffectAllow,
},
{
name: "deny delete access (explicit deny)",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:DeleteObject",
Resource: "arn:aws:s3:::public-bucket/file.txt",
},
policies: []string{"read-policy", "deny-policy"},
want: EffectDeny,
},
{
name: "deny by default (no matching policy)",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:PutObject",
Resource: "arn:aws:s3:::public-bucket/file.txt",
},
policies: []string{"read-policy"},
want: EffectDeny,
},
{
name: "allow with wildcard action",
context: &EvaluationContext{
Principal: "user:admin",
Action: "s3:ListBucket",
Resource: "arn:aws:s3:::public-bucket",
},
policies: []string{"read-policy"},
want: EffectAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := engine.Evaluate(context.Background(), "", tt.context, tt.policies)
assert.NoError(t, err)
assert.Equal(t, tt.want, result.Effect)
// Verify evaluation details
assert.NotNil(t, result.EvaluationDetails)
assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action)
assert.Equal(t, tt.context.Resource, result.EvaluationDetails.Resource)
})
}
}
// TestConditionEvaluation tests policy conditions
func TestConditionEvaluation(t *testing.T) {
engine := setupTestPolicyEngine(t)
// Policy with IP address condition
conditionalPolicy := &PolicyDocument{
Version: "2012-10-17",
Statement: []Statement{
{
Sid: "AllowFromOfficeIP",
Effect: "Allow",
Action: []string{"s3:*"},
Resource: []string{"arn:aws:s3:::*"},
Condition: map[string]map[string]interface{}{
"IpAddress": {
"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.0/8"},
},
},
},
},
}
err := engine.AddPolicy("", "ip-conditional", conditionalPolicy)
require.NoError(t, err)
tests := []struct {
name string
context *EvaluationContext
want Effect
}{
{
name: "allow from office IP",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:GetObject",
Resource: "arn:aws:s3:::mybucket/file.txt",
RequestContext: map[string]interface{}{
"aws:SourceIp": "192.168.1.100",
},
},
want: EffectAllow,
},
{
name: "deny from external IP",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:GetObject",
Resource: "arn:aws:s3:::mybucket/file.txt",
RequestContext: map[string]interface{}{
"aws:SourceIp": "8.8.8.8",
},
},
want: EffectDeny,
},
{
name: "allow from internal IP",
context: &EvaluationContext{
Principal: "user:alice",
Action: "s3:PutObject",
Resource: "arn:aws:s3:::mybucket/newfile.txt",
RequestContext: map[string]interface{}{
"aws:SourceIp": "10.1.2.3",
},
},
want: EffectAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := engine.Evaluate(context.Background(), "", tt.context, []string{"ip-conditional"})
assert.NoError(t, err)
assert.Equal(t, tt.want, result.Effect)
})
}
}
// TestResourceMatching tests resource ARN matching: exact strings,
// trailing-wildcard prefixes, and mismatching buckets.
func TestResourceMatching(t *testing.T) {
	cases := []struct {
		name     string
		pattern  string
		resource string
		expected bool
	}{
		{"exact match", "arn:aws:s3:::mybucket/file.txt", "arn:aws:s3:::mybucket/file.txt", true},
		{"wildcard match", "arn:aws:s3:::mybucket/*", "arn:aws:s3:::mybucket/folder/file.txt", true},
		{"bucket wildcard", "arn:aws:s3:::*", "arn:aws:s3:::anybucket/file.txt", true},
		{"no match different bucket", "arn:aws:s3:::mybucket/*", "arn:aws:s3:::otherbucket/file.txt", false},
		{"prefix match", "arn:aws:s3:::mybucket/documents/*", "arn:aws:s3:::mybucket/documents/secret.txt", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, matchResource(tc.pattern, tc.resource))
		})
	}
}
// TestActionMatching tests action pattern matching: exact, service-level
// wildcards, the global wildcard, prefix wildcards, and mismatches.
func TestActionMatching(t *testing.T) {
	cases := []struct {
		name    string
		pattern string
		action  string
		want    bool
	}{
		{"exact match", "s3:GetObject", "s3:GetObject", true},
		{"wildcard service", "s3:*", "s3:PutObject", true},
		{"wildcard all", "*", "filer:CreateEntry", true},
		{"prefix match", "s3:Get*", "s3:GetObject", true},
		{"no match different service", "s3:GetObject", "filer:GetEntry", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, matchAction(tc.pattern, tc.action))
		})
	}
}
// setupTestPolicyEngine creates a memory-backed PolicyEngine with a default
// Deny effect for use in tests. Initialization failures abort the calling
// test immediately via require.
func setupTestPolicyEngine(t *testing.T) *PolicyEngine {
	// Mark as a helper so failures are attributed to the caller's line,
	// not to this shared setup function.
	t.Helper()
	engine := NewPolicyEngine()
	config := &PolicyEngineConfig{
		DefaultEffect: "Deny",
		StoreType:     "memory",
	}
	err := engine.Initialize(config)
	require.NoError(t, err)
	return engine
}

View File

@@ -1,246 +0,0 @@
package providers
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIdentityProviderInterface tests the core identity provider interface.
// NOTE(review): the case table is currently empty, so the loop body never
// executes; it documents the intended contract (non-empty name, clean
// initialization, rejection of an invalid token) for future provider cases.
func TestIdentityProviderInterface(t *testing.T) {
tests := []struct {
name string
provider IdentityProvider
wantErr bool
}{
// We'll add test cases as we implement providers
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test provider name
name := tt.provider.Name()
assert.NotEmpty(t, name, "Provider name should not be empty")
// Test initialization; wantErr cases stop here
err := tt.provider.Initialize(nil)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
// Test authentication with invalid token
ctx := context.Background()
_, err = tt.provider.Authenticate(ctx, "invalid-token")
assert.Error(t, err, "Should fail with invalid token")
})
}
}
// TestExternalIdentityValidation tests external identity structure validation
func TestExternalIdentityValidation(t *testing.T) {
tests := []struct {
name string
identity *ExternalIdentity
wantErr bool
}{
{
name: "valid identity",
identity: &ExternalIdentity{
UserID: "user123",
Email: "user@example.com",
DisplayName: "Test User",
Groups: []string{"group1", "group2"},
Attributes: map[string]string{"dept": "engineering"},
Provider: "test-provider",
},
wantErr: false,
},
{
name: "missing user id",
identity: &ExternalIdentity{
Email: "user@example.com",
Provider: "test-provider",
},
wantErr: true,
},
{
name: "missing provider",
identity: &ExternalIdentity{
UserID: "user123",
Email: "user@example.com",
},
wantErr: true,
},
{
name: "invalid email",
identity: &ExternalIdentity{
UserID: "user123",
Email: "invalid-email",
Provider: "test-provider",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.identity.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestTokenClaimsValidation tests token claims structure
func TestTokenClaimsValidation(t *testing.T) {
tests := []struct {
name string
claims *TokenClaims
valid bool
}{
{
name: "valid claims",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now().Add(-time.Minute),
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: true,
},
{
name: "expired token",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(-time.Hour), // Expired
IssuedAt: time.Now().Add(-time.Hour * 2),
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: false,
},
{
name: "future issued token",
claims: &TokenClaims{
Subject: "user123",
Issuer: "https://provider.example.com",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now().Add(time.Hour), // Future
Claims: map[string]interface{}{"email": "user@example.com"},
},
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
valid := tt.claims.IsValid()
assert.Equal(t, tt.valid, valid)
})
}
}
// TestProviderRegistry tests provider registration and discovery.
// The three subtests intentionally share one registry instance and rely on
// running in declaration order: registration happens first, then lookup and
// listing observe that state.
func TestProviderRegistry(t *testing.T) {
// Clear registry for test
registry := NewProviderRegistry()
t.Run("register provider", func(t *testing.T) {
mockProvider := &MockProvider{name: "test-provider"}
err := registry.RegisterProvider(mockProvider)
assert.NoError(t, err)
// Test duplicate registration of the same name
err = registry.RegisterProvider(mockProvider)
assert.Error(t, err, "Should not allow duplicate registration")
})
t.Run("get provider", func(t *testing.T) {
provider, exists := registry.GetProvider("test-provider")
assert.True(t, exists)
assert.Equal(t, "test-provider", provider.Name())
// Test non-existent provider
_, exists = registry.GetProvider("non-existent")
assert.False(t, exists)
})
t.Run("list providers", func(t *testing.T) {
providers := registry.ListProviders()
assert.Len(t, providers, 1)
assert.Equal(t, "test-provider", providers[0])
})
}
// MockProvider is a configurable stand-in identity provider for testing.
type MockProvider struct {
name string // reported by Name and stamped into returned identities
initialized bool // set by Initialize; the other methods fail until then
shouldError bool // when true, Initialize returns assert.AnError
}
// Name returns the mock provider's configured name.
func (m *MockProvider) Name() string {
return m.name
}
// Initialize marks the mock ready for use. When shouldError is set, it
// fails with assert.AnError without changing state.
func (m *MockProvider) Initialize(config interface{}) error {
if m.shouldError {
return assert.AnError
}
m.initialized = true
return nil
}
// Authenticate returns a fixed test identity for any token other than
// "invalid-token". It fails when the provider has not been initialized.
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
	if !m.initialized || token == "invalid-token" {
		return nil, assert.AnError
	}
	identity := &ExternalIdentity{
		UserID:      "test-user",
		Email:       "test@example.com",
		DisplayName: "Test User",
		Provider:    m.name,
	}
	return identity, nil
}
// GetUserInfo returns a synthetic identity derived from userID. It fails
// when the provider is uninitialized or userID is empty.
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
if !m.initialized || userID == "" {
return nil, assert.AnError
}
return &ExternalIdentity{
UserID: userID,
Email: userID + "@example.com",
DisplayName: "User " + userID,
Provider: m.name,
}, nil
}
// ValidateToken returns fixed claims (valid for one hour from now) for any
// token other than "invalid-token"; it fails when uninitialized.
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
if !m.initialized || token == "invalid-token" {
return nil, assert.AnError
}
return &TokenClaims{
Subject: "test-user",
Issuer: "test-issuer",
Audience: "seaweedfs",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now(),
Claims: map[string]interface{}{"email": "test@example.com"},
}, nil
}

View File

@@ -1,109 +0,0 @@
package providers
import (
"fmt"
"sync"
)
// ProviderRegistry manages registered identity providers.
// All methods are safe for concurrent use; mu guards the providers map.
type ProviderRegistry struct {
mu sync.RWMutex
providers map[string]IdentityProvider
}
// NewProviderRegistry returns an empty, ready-to-use registry.
func NewProviderRegistry() *ProviderRegistry {
	r := &ProviderRegistry{}
	r.providers = make(map[string]IdentityProvider)
	return r
}
// RegisterProvider adds provider to the registry. It rejects a nil
// provider, a provider with an empty name, and duplicate registrations.
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error {
	if provider == nil {
		return fmt.Errorf("provider cannot be nil")
	}
	name := provider.Name()
	if name == "" {
		return fmt.Errorf("provider name cannot be empty")
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	// Registration is first-wins: an existing entry is never overwritten.
	if _, dup := r.providers[name]; dup {
		return fmt.Errorf("provider %s is already registered", name)
	}
	r.providers[name] = provider
	return nil
}
// GetProvider looks up a registered provider by name. The boolean result
// reports whether the name was found.
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) {
	r.mu.RLock()
	p, ok := r.providers[name]
	r.mu.RUnlock()
	return p, ok
}
// ListProviders returns the names of all registered providers. Order is
// unspecified (Go map iteration order is randomized).
func (r *ProviderRegistry) ListProviders() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	// Pre-size to the known count to avoid repeated append growth.
	// Note: an empty registry now yields an empty (non-nil) slice, which is
	// indistinguishable from nil for len/range callers.
	names := make([]string, 0, len(r.providers))
	for name := range r.providers {
		names = append(names, name)
	}
	return names
}
// UnregisterProvider removes the named provider, returning an error if no
// provider with that name was registered.
func (r *ProviderRegistry) UnregisterProvider(name string) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	_, ok := r.providers[name]
	if !ok {
		return fmt.Errorf("provider %s is not registered", name)
	}
	delete(r.providers, name)
	return nil
}
// Clear removes all providers from the registry.
// A fresh map is allocated under the write lock, so providers registered
// earlier are no longer reachable through this registry.
func (r *ProviderRegistry) Clear() {
r.mu.Lock()
defer r.mu.Unlock()
r.providers = make(map[string]IdentityProvider)
}
// GetProviderCount returns the number of registered providers.
// It takes only the read lock and is safe to call concurrently.
func (r *ProviderRegistry) GetProviderCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.providers)
}
// Default global registry backing the package-level convenience functions
// below. All of them simply delegate to defaultRegistry.
var defaultRegistry = NewProviderRegistry()
// RegisterProvider registers a provider in the default registry.
func RegisterProvider(provider IdentityProvider) error {
return defaultRegistry.RegisterProvider(provider)
}
// GetProvider retrieves a provider from the default registry.
func GetProvider(name string) (IdentityProvider, bool) {
return defaultRegistry.GetProvider(name)
}
// ListProviders returns all provider names from the default registry.
func ListProviders() []string {
return defaultRegistry.ListProviders()
}

View File

@@ -1,503 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/oidc"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test-only constants for mock providers.
const (
// ProviderTypeMock identifies the mock OIDC provider type used in tests;
// see createMockOIDCProvider below.
ProviderTypeMock = "mock"
)
// createMockOIDCProvider creates a mock OIDC provider for testing.
// It reuses the production factory's config conversion so the mock accepts
// the same map-shaped configuration as real providers, then seeds the mock
// with default test data. This is only available in test builds.
func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
// Convert the generic config map to the typed OIDC configuration
factory := NewProviderFactory()
oidcConfig, err := factory.convertToOIDCConfig(config)
if err != nil {
return nil, err
}
// Set default values for mock provider if not provided
if oidcConfig.Issuer == "" {
oidcConfig.Issuer = "http://localhost:9999"
}
provider := oidc.NewMockOIDCProvider(name)
if err := provider.Initialize(oidcConfig); err != nil {
return nil, err
}
// Set up default test data for the mock provider
provider.SetupDefaultTestData()
return provider, nil
}
// createMockJWT creates a test JWT token with the specified issuer for mock
// provider testing. The token is HS256-signed with a fixed test key and is
// valid for one hour from creation; signing failures abort the test.
func createMockJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
// can be used and validated by other STS instances in a distributed environment
//
// Three independent STSService instances are created with an identical issuer
// and signing key (the requirement for stateless cross-instance JWTs), then
// the test checks token validation, the full assume-role flow, revocation
// semantics, and provider-set consistency across all three.
func TestCrossInstanceTokenUsage(t *testing.T) {
	ctx := context.Background()
	// Dummy filer address for testing
	// Common configuration that would be shared across all instances in production
	sharedConfig := &STSConfig{
		TokenDuration:    FlexibleDuration{time.Hour},
		MaxSessionLength: FlexibleDuration{12 * time.Hour},
		Issuer:           "distributed-sts-cluster",     // SAME across all instances
		SigningKey:       []byte(TestSigningKey32Chars), // SAME across all instances
		Providers: []*ProviderConfig{
			{
				Name:    "company-oidc",
				Type:    ProviderTypeOIDC,
				Enabled: true,
				Config: map[string]interface{}{
					ConfigFieldIssuer:   "https://sso.company.com/realms/production",
					ConfigFieldClientID: "seaweedfs-cluster",
					ConfigFieldJWKSUri:  "https://sso.company.com/realms/production/protocol/openid-connect/certs",
				},
			},
		},
	}
	// Create multiple STS instances simulating different S3 gateway instances
	instanceA := NewSTSService() // e.g., s3-gateway-1
	instanceB := NewSTSService() // e.g., s3-gateway-2
	instanceC := NewSTSService() // e.g., s3-gateway-3
	// Initialize all instances with IDENTICAL configuration
	err := instanceA.Initialize(sharedConfig)
	require.NoError(t, err, "Instance A should initialize")
	err = instanceB.Initialize(sharedConfig)
	require.NoError(t, err, "Instance B should initialize")
	err = instanceC.Initialize(sharedConfig)
	require.NoError(t, err, "Instance C should initialize")
	// Set up mock trust policy validator for all instances (required for STS testing)
	mockValidator := &MockTrustPolicyValidator{}
	instanceA.SetTrustPolicyValidator(mockValidator)
	instanceB.SetTrustPolicyValidator(mockValidator)
	instanceC.SetTrustPolicyValidator(mockValidator)
	// Manually register mock provider for testing (not available in production).
	// Each instance gets its own provider object, but all share the same name
	// and issuer so a token minted against one is accepted by the others.
	mockProviderConfig := map[string]interface{}{
		ConfigFieldIssuer:   "http://test-mock:9999",
		ConfigFieldClientID: TestClientID,
	}
	mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
	require.NoError(t, err)
	mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
	require.NoError(t, err)
	mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
	require.NoError(t, err)
	instanceA.RegisterProvider(mockProviderA)
	instanceB.RegisterProvider(mockProviderB)
	instanceC.RegisterProvider(mockProviderC)
	// Test 1: Token generated on Instance A can be validated on Instance B & C
	t.Run("cross_instance_token_validation", func(t *testing.T) {
		// Generate session token on Instance A
		sessionId := TestSessionID
		expiresAt := time.Now().Add(time.Hour)
		tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err, "Instance A should generate token")
		// Validate token on Instance B
		claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
		require.NoError(t, err, "Instance B should validate token from Instance A")
		assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
		// Validate same token on Instance C
		claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
		require.NoError(t, err, "Instance C should validate token from Instance A")
		assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
		// All instances should extract identical claims
		assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
		assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
		assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
	})
	// Test 2: Complete assume role flow across instances
	t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
		// Step 1: User authenticates and assumes role on Instance A
		// Create a valid JWT token for the mock provider
		mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
		assumeRequest := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/CrossInstanceTestRole",
			WebIdentityToken: mockToken, // JWT token for mock provider
			RoleSessionName:  "cross-instance-test-session",
			DurationSeconds:  int64ToPtr(3600),
		}
		// Instance A processes assume role request
		responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		require.NoError(t, err, "Instance A should process assume role")
		sessionToken := responseFromA.Credentials.SessionToken
		accessKeyId := responseFromA.Credentials.AccessKeyId
		secretAccessKey := responseFromA.Credentials.SecretAccessKey
		// Verify response structure
		assert.NotEmpty(t, sessionToken, "Should have session token")
		assert.NotEmpty(t, accessKeyId, "Should have access key ID")
		assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
		assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
		// Step 2: Use session token on Instance B (different instance)
		sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Instance B should validate session token from Instance A")
		assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
		assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
		// Step 3: Use same session token on Instance C (yet another instance)
		sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Instance C should validate session token from Instance A")
		// All instances should return identical session information
		assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
		assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
		assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
		assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
		assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
	})
	// Test 3: Session revocation across instances
	t.Run("cross_instance_session_revocation", func(t *testing.T) {
		// Create session on Instance A
		mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
		assumeRequest := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/RevocationTestRole",
			WebIdentityToken: mockToken,
			RoleSessionName:  "revocation-test-session",
		}
		response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		require.NoError(t, err)
		sessionToken := response.Credentials.SessionToken
		// Verify token works on Instance B
		_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Token should be valid on Instance B initially")
		// Validate session on Instance C to verify cross-instance token compatibility
		_, err = instanceC.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Instance C should be able to validate session token")
		// In a stateless JWT system, tokens remain valid on all instances since they're self-contained
		// No revocation is possible without breaking the stateless architecture
		_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
		assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
		// Verify token is still valid on Instance B
		_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
		assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
	})
	// Test 4: Provider consistency across instances
	t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
		// All instances should have same providers and be able to process same OIDC tokens
		providerNamesA := instanceA.getProviderNames()
		providerNamesB := instanceB.getProviderNames()
		providerNamesC := instanceC.getProviderNames()
		assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
		assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
		// All instances should be able to process same web identity token
		testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
		// Try to assume role with same token on different instances
		assumeRequest := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/ProviderTestRole",
			WebIdentityToken: testToken,
			RoleSessionName:  "provider-consistency-test",
		}
		// Should work on any instance
		responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		require.NoError(t, errA, "Instance A should process OIDC token")
		require.NoError(t, errB, "Instance B should process OIDC token")
		require.NoError(t, errC, "Instance C should process OIDC token")
		// All should return valid responses (sessions will have different IDs but same structure)
		assert.NotEmpty(t, responseA.Credentials.SessionToken)
		assert.NotEmpty(t, responseB.Credentials.SessionToken)
		assert.NotEmpty(t, responseC.Credentials.SessionToken)
	})
}
// TestSTSDistributedConfigurationRequirements tests the configuration requirements
// for cross-instance token compatibility
//
// It verifies the two invariants a distributed deployment must hold for
// stateless JWT sessions to work: every instance must share the same signing
// key and the same issuer. It also confirms that N instances with identical
// configuration all accept each other's tokens.
func TestSTSDistributedConfigurationRequirements(t *testing.T) {
	t.Run("same_signing_key_required", func(t *testing.T) {
		// Instance A with signing key 1
		configA := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "test-sts",
			SigningKey:       []byte("signing-key-1-32-characters-long"),
		}
		// Instance B with different signing key
		configB := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "test-sts",
			SigningKey:       []byte("signing-key-2-32-characters-long"), // DIFFERENT!
		}
		instanceA := NewSTSService()
		instanceB := NewSTSService()
		err := instanceA.Initialize(configA)
		require.NoError(t, err)
		err = instanceB.Initialize(configB)
		require.NoError(t, err)
		// Generate token on Instance A
		sessionId := "test-session"
		expiresAt := time.Now().Add(time.Hour)
		tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err)
		// Instance A should validate its own token
		_, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
		assert.NoError(t, err, "Instance A should validate own token")
		// Instance B should REJECT token due to different signing key
		_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
		assert.Error(t, err, "Instance B should reject token with different signing key")
		assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
	})
	t.Run("same_issuer_required", func(t *testing.T) {
		sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
		// Instance A with issuer 1
		configA := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "sts-cluster-1",
			SigningKey:       sharedSigningKey,
		}
		// Instance B with different issuer
		configB := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "sts-cluster-2", // DIFFERENT!
			SigningKey:       sharedSigningKey,
		}
		instanceA := NewSTSService()
		instanceB := NewSTSService()
		err := instanceA.Initialize(configA)
		require.NoError(t, err)
		err = instanceB.Initialize(configB)
		require.NoError(t, err)
		// Generate token on Instance A
		sessionId := "test-session"
		expiresAt := time.Now().Add(time.Hour)
		tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err)
		// Instance B should REJECT token due to different issuer, even though
		// the signature verifies (same key).
		_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
		assert.Error(t, err, "Instance B should reject token with different issuer")
		assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
	})
	t.Run("identical_configuration_required", func(t *testing.T) {
		// Identical configuration
		identicalConfig := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "production-sts-cluster",
			SigningKey:       []byte("production-signing-key-32-chars-l"),
		}
		// Create multiple instances with identical config
		instances := make([]*STSService, 5)
		for i := 0; i < 5; i++ {
			instances[i] = NewSTSService()
			err := instances[i].Initialize(identicalConfig)
			require.NoError(t, err, "Instance %d should initialize", i)
		}
		// Generate token on Instance 0
		sessionId := "multi-instance-test"
		expiresAt := time.Now().Add(time.Hour)
		token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err)
		// All other instances should validate the token
		for i := 1; i < 5; i++ {
			claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
			require.NoError(t, err, "Instance %d should validate token", i)
			assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
		}
	})
}
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
//
// It simulates three S3 gateways behind a load balancer sharing one STS
// configuration: the user assumes a role on gateway 1 and subsequent requests
// carrying the session token land on gateways 2 and 3, which must both accept
// and agree on the session.
func TestSTSRealWorldDistributedScenarios(t *testing.T) {
	ctx := context.Background()
	t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
		// Simulate real production scenario:
		// 1. User authenticates with OIDC provider
		// 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
		// 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
		// 4. All instances should handle the session token correctly
		productionConfig := &STSConfig{
			TokenDuration:    FlexibleDuration{2 * time.Hour},
			MaxSessionLength: FlexibleDuration{24 * time.Hour},
			Issuer:           "seaweedfs-production-sts",
			SigningKey:       []byte("prod-signing-key-32-characters-lon"),
			Providers: []*ProviderConfig{
				{
					Name:    "corporate-oidc",
					Type:    "oidc",
					Enabled: true,
					Config: map[string]interface{}{
						"issuer":       "https://sso.company.com/realms/production",
						"clientId":     "seaweedfs-prod-cluster",
						"clientSecret": "supersecret-prod-key",
						"scopes":       []string{"openid", "profile", "email", "groups"},
					},
				},
			},
		}
		// Create 3 S3 Gateway instances behind load balancer
		gateway1 := NewSTSService()
		gateway2 := NewSTSService()
		gateway3 := NewSTSService()
		err := gateway1.Initialize(productionConfig)
		require.NoError(t, err)
		err = gateway2.Initialize(productionConfig)
		require.NoError(t, err)
		err = gateway3.Initialize(productionConfig)
		require.NoError(t, err)
		// Set up mock trust policy validator for all gateway instances
		mockValidator := &MockTrustPolicyValidator{}
		gateway1.SetTrustPolicyValidator(mockValidator)
		gateway2.SetTrustPolicyValidator(mockValidator)
		gateway3.SetTrustPolicyValidator(mockValidator)
		// Manually register mock provider for testing (not available in production)
		mockProviderConfig := map[string]interface{}{
			ConfigFieldIssuer:   "http://test-mock:9999",
			ConfigFieldClientID: "test-client-id",
		}
		mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
		require.NoError(t, err)
		mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
		require.NoError(t, err)
		mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
		require.NoError(t, err)
		gateway1.RegisterProvider(mockProvider1)
		gateway2.RegisterProvider(mockProvider2)
		gateway3.RegisterProvider(mockProvider3)
		// Step 1: User authenticates and hits Gateway 1 for AssumeRole
		mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
		assumeRequest := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/ProductionS3User",
			WebIdentityToken: mockToken, // JWT token from mock provider
			RoleSessionName:  "user-production-session",
			DurationSeconds:  int64ToPtr(7200), // 2 hours
		}
		stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
		require.NoError(t, err, "Gateway 1 should handle AssumeRole")
		sessionToken := stsResponse.Credentials.SessionToken
		accessKey := stsResponse.Credentials.AccessKeyId
		secretKey := stsResponse.Credentials.SecretAccessKey
		// Step 2: User makes S3 requests that hit different gateways via load balancer
		// Simulate S3 request validation on Gateway 2
		sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
		assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
		assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn)
		// Simulate S3 request validation on Gateway 3
		sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
		require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
		assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
		// Step 3: Verify Gateway 1 actually issued a full credential set.
		// (The previous version compared accessKey/secretKey against the very
		// fields they were copied from, which was trivially true.)
		assert.NotEmpty(t, accessKey, "Access key should be issued")
		assert.NotEmpty(t, secretKey, "Secret key should be issued")
		// Step 4: Session expiration should be honored across all instances
		assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
		assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
		// Step 5: Token should be identical when parsed
		claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
		require.NoError(t, err)
		claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
		require.NoError(t, err)
		assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
		assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
	})
}
// int64ToPtr returns a pointer to a copy of the given int64 value, handy for
// populating optional request fields that take *int64.
func int64ToPtr(i int64) *int64 {
	v := i
	return &v
}

View File

@@ -1,340 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDistributedSTSService verifies that multiple STS instances with identical configurations
// behave consistently across distributed environments
//
// Three STSService instances are initialized from the same config (one enabled
// OIDC provider, one disabled placeholder) and the test asserts that the
// provider sets, token generation, cross-instance validation, and provider
// authentication results are identical on every instance.
func TestDistributedSTSService(t *testing.T) {
	ctx := context.Background()
	// Common configuration for all instances
	commonConfig := &STSConfig{
		TokenDuration:    FlexibleDuration{time.Hour},
		MaxSessionLength: FlexibleDuration{12 * time.Hour},
		Issuer:           "distributed-sts-test",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
		Providers: []*ProviderConfig{
			{
				Name:    "keycloak-oidc",
				Type:    "oidc",
				Enabled: true,
				Config: map[string]interface{}{
					"issuer":   "http://keycloak:8080/realms/seaweedfs-test",
					"clientId": "seaweedfs-s3",
					"jwksUri":  "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs",
				},
			},
			{
				Name:    "disabled-ldap",
				Type:    "oidc", // Use OIDC as placeholder since LDAP isn't implemented
				Enabled: false,
				Config: map[string]interface{}{
					"issuer":   "ldap://company.com",
					"clientId": "ldap-client",
				},
			},
		},
	}
	// Create multiple STS instances simulating distributed deployment
	instance1 := NewSTSService()
	instance2 := NewSTSService()
	instance3 := NewSTSService()
	// Initialize all instances with identical configuration
	err := instance1.Initialize(commonConfig)
	require.NoError(t, err, "Instance 1 should initialize successfully")
	err = instance2.Initialize(commonConfig)
	require.NoError(t, err, "Instance 2 should initialize successfully")
	err = instance3.Initialize(commonConfig)
	require.NoError(t, err, "Instance 3 should initialize successfully")
	// Manually register mock providers for testing (not available in production)
	mockProviderConfig := map[string]interface{}{
		"issuer":   "http://localhost:9999",
		"clientId": "test-client",
	}
	mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
	require.NoError(t, err)
	mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
	require.NoError(t, err)
	mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
	require.NoError(t, err)
	instance1.RegisterProvider(mockProvider1)
	instance2.RegisterProvider(mockProvider2)
	instance3.RegisterProvider(mockProvider3)
	// Verify all instances have identical provider configurations
	t.Run("provider_consistency", func(t *testing.T) {
		// All instances should have same number of providers
		// (keycloak-oidc from config + the manually registered mock; the
		// disabled-ldap entry must have been skipped during Initialize)
		assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers")
		assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers")
		assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers")
		// All instances should have same provider names
		instance1Names := instance1.getProviderNames()
		instance2Names := instance2.getProviderNames()
		instance3Names := instance3.getProviderNames()
		assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers")
		assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers")
		// Verify specific providers exist on all instances
		expectedProviders := []string{"keycloak-oidc", "test-mock-provider"}
		assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers")
		assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers")
		assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers")
		// Verify disabled providers are not loaded
		assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded")
		assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded")
		assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded")
	})
	// Test token generation consistency across instances
	t.Run("token_generation_consistency", func(t *testing.T) {
		sessionId := "test-session-123"
		expiresAt := time.Now().Add(time.Hour)
		// Generate tokens from different instances
		token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err1, "Instance 1 token generation should succeed")
		require.NoError(t, err2, "Instance 2 token generation should succeed")
		require.NoError(t, err3, "Instance 3 token generation should succeed")
		// All tokens should be different (due to timestamp variations)
		// But they should all be valid JWTs with same signing key
		assert.NotEmpty(t, token1)
		assert.NotEmpty(t, token2)
		assert.NotEmpty(t, token3)
	})
	// Test token validation consistency - any instance should validate tokens from any other instance
	t.Run("cross_instance_token_validation", func(t *testing.T) {
		sessionId := "cross-validation-session"
		expiresAt := time.Now().Add(time.Hour)
		// Generate token on instance 1
		token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
		require.NoError(t, err)
		// Validate on all instances
		claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
		claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
		claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
		require.NoError(t, err1, "Instance 1 should validate token from instance 1")
		require.NoError(t, err2, "Instance 2 should validate token from instance 1")
		require.NoError(t, err3, "Instance 3 should validate token from instance 1")
		// All instances should extract same session ID
		assert.Equal(t, sessionId, claims1.SessionId)
		assert.Equal(t, sessionId, claims2.SessionId)
		assert.Equal(t, sessionId, claims3.SessionId)
		assert.Equal(t, claims1.SessionId, claims2.SessionId)
		assert.Equal(t, claims2.SessionId, claims3.SessionId)
	})
	// Test provider access consistency
	t.Run("provider_access_consistency", func(t *testing.T) {
		// All instances should be able to access the same providers
		provider1, exists1 := instance1.providers["test-mock-provider"]
		provider2, exists2 := instance2.providers["test-mock-provider"]
		provider3, exists3 := instance3.providers["test-mock-provider"]
		assert.True(t, exists1, "Instance 1 should have test-mock-provider")
		assert.True(t, exists2, "Instance 2 should have test-mock-provider")
		assert.True(t, exists3, "Instance 3 should have test-mock-provider")
		assert.Equal(t, provider1.Name(), provider2.Name())
		assert.Equal(t, provider2.Name(), provider3.Name())
		// Test authentication with the mock provider on all instances
		testToken := "valid_test_token"
		identity1, err1 := provider1.Authenticate(ctx, testToken)
		identity2, err2 := provider2.Authenticate(ctx, testToken)
		identity3, err3 := provider3.Authenticate(ctx, testToken)
		require.NoError(t, err1, "Instance 1 provider should authenticate successfully")
		require.NoError(t, err2, "Instance 2 provider should authenticate successfully")
		require.NoError(t, err3, "Instance 3 provider should authenticate successfully")
		// All instances should return identical identity information
		assert.Equal(t, identity1.UserID, identity2.UserID)
		assert.Equal(t, identity2.UserID, identity3.UserID)
		assert.Equal(t, identity1.Email, identity2.Email)
		assert.Equal(t, identity2.Email, identity3.Email)
		assert.Equal(t, identity1.Provider, identity2.Provider)
		assert.Equal(t, identity2.Provider, identity3.Provider)
	})
}
// TestSTSConfigurationValidation tests configuration validation for distributed deployments
//
// Two subtests demonstrate that tokens are only portable between services that
// share both the signing key and the issuer.
func TestSTSConfigurationValidation(t *testing.T) {
	t.Run("consistent_signing_keys_required", func(t *testing.T) {
		// Two services identical except for the signing key.
		cfgOne := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "test-sts",
			SigningKey:       []byte("signing-key-1-32-characters-long"),
		}
		cfgTwo := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "test-sts",
			SigningKey:       []byte("signing-key-2-32-characters-long"), // Different key!
		}
		svcOne := NewSTSService()
		svcTwo := NewSTSService()
		require.NoError(t, svcOne.Initialize(cfgOne))
		require.NoError(t, svcTwo.Initialize(cfgTwo))
		// Mint a token on the first service.
		token, err := svcOne.GetTokenGenerator().GenerateSessionToken("test-session", time.Now().Add(time.Hour))
		require.NoError(t, err)
		// The issuing service accepts its own token.
		_, err = svcOne.GetTokenGenerator().ValidateSessionToken(token)
		assert.NoError(t, err, "Instance 1 should validate its own token")
		// The other service must reject it: the signature cannot verify.
		_, err = svcTwo.GetTokenGenerator().ValidateSessionToken(token)
		assert.Error(t, err, "Instance 2 should reject token with different signing key")
	})
	t.Run("consistent_issuer_required", func(t *testing.T) {
		sharedKey := []byte("shared-signing-key-32-characters-lo")
		// Two services identical except for the issuer.
		cfgOne := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "sts-instance-1",
			SigningKey:       sharedKey,
		}
		cfgTwo := &STSConfig{
			TokenDuration:    FlexibleDuration{time.Hour},
			MaxSessionLength: FlexibleDuration{12 * time.Hour},
			Issuer:           "sts-instance-2", // Different issuer!
			SigningKey:       sharedKey,
		}
		svcOne := NewSTSService()
		svcTwo := NewSTSService()
		require.NoError(t, svcOne.Initialize(cfgOne))
		require.NoError(t, svcTwo.Initialize(cfgTwo))
		// Mint a token on the first service.
		token, err := svcOne.GetTokenGenerator().GenerateSessionToken("test-session", time.Now().Add(time.Hour))
		require.NoError(t, err)
		// Even with a matching signature, the issuer claim mismatch must
		// cause rejection on the second service.
		_, err = svcTwo.GetTokenGenerator().ValidateSessionToken(token)
		assert.Error(t, err, "Instance 2 should reject token with different issuer")
	})
}
// TestProviderFactoryDistributed tests the provider factory in distributed scenarios
//
// Loading one shared configuration three times (as three instances would at
// startup) must yield the same set of enabled providers each time, with
// disabled entries consistently skipped.
func TestProviderFactoryDistributed(t *testing.T) {
	factory := NewProviderFactory()
	// Configuration that would be identical across all instances: one enabled
	// production provider plus one disabled backup entry.
	configs := []*ProviderConfig{
		{
			Name:    "production-keycloak",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":       "https://keycloak.company.com/realms/seaweedfs",
				"clientId":     "seaweedfs-prod",
				"clientSecret": "super-secret-key",
				"jwksUri":      "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs",
				"scopes":       []string{"openid", "profile", "email", "roles"},
			},
		},
		{
			Name:    "backup-oidc",
			Type:    "oidc",
			Enabled: false, // Disabled by default
			Config: map[string]interface{}{
				"issuer":   "https://backup-oidc.company.com",
				"clientId": "seaweedfs-backup",
			},
		},
	}
	// Load the same config three times, as three separate instances would.
	loadedFirst, errFirst := factory.LoadProvidersFromConfig(configs)
	loadedSecond, errSecond := factory.LoadProvidersFromConfig(configs)
	loadedThird, errThird := factory.LoadProvidersFromConfig(configs)
	require.NoError(t, errFirst, "First load should succeed")
	require.NoError(t, errSecond, "Second load should succeed")
	require.NoError(t, errThird, "Third load should succeed")
	// Each load should contain exactly the one enabled provider.
	assert.Len(t, loadedFirst, 1, "First instance should have 1 enabled provider")
	assert.Len(t, loadedSecond, 1, "Second instance should have 1 enabled provider")
	assert.Len(t, loadedThird, 1, "Third instance should have 1 enabled provider")
	// Collect the provider names from each load.
	var namesFirst, namesSecond, namesThird []string
	for name := range loadedFirst {
		namesFirst = append(namesFirst, name)
	}
	for name := range loadedSecond {
		namesSecond = append(namesSecond, name)
	}
	for name := range loadedThird {
		namesThird = append(namesThird, name)
	}
	assert.ElementsMatch(t, namesFirst, namesSecond, "Instance 1 and 2 should have same provider names")
	assert.ElementsMatch(t, namesSecond, namesThird, "Instance 2 and 3 should have same provider names")
	// Only the enabled provider should appear...
	assert.ElementsMatch(t, namesFirst, []string{"production-keycloak"}, "Should have expected enabled providers")
	// ...and the disabled one should never be loaded.
	assert.NotContains(t, namesFirst, "backup-oidc", "Disabled providers should not be loaded")
	assert.NotContains(t, namesSecond, "backup-oidc", "Disabled providers should not be loaded")
	assert.NotContains(t, namesThird, "backup-oidc", "Disabled providers should not be loaded")
}

View File

@@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro
return roleMapping, nil
}
// ValidateProviderConfig validates a provider configuration
//
// It checks the fields common to every provider (non-nil struct, non-empty
// name and type, non-nil config map) and then dispatches to the type-specific
// validator. Returns a descriptive error for the first problem found, or an
// error for an unsupported provider type.
func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
	if config == nil {
		return fmt.Errorf("provider config cannot be nil")
	}
	if config.Name == "" {
		return fmt.Errorf("provider name cannot be empty")
	}
	if config.Type == "" {
		return fmt.Errorf("provider type cannot be empty")
	}
	if config.Config == nil {
		// Distinct message from the nil-struct case above so callers can tell
		// which of the two was actually missing (previously both said
		// "provider config cannot be nil").
		return fmt.Errorf("provider config map cannot be nil")
	}
	// Type-specific validation
	switch config.Type {
	case "oidc":
		return f.validateOIDCConfig(config.Config)
	case "ldap":
		return f.validateLDAPConfig(config.Config)
	case "saml":
		return f.validateSAMLConfig(config.Config)
	default:
		return fmt.Errorf("unsupported provider type: %s", config.Type)
	}
}
// validateOIDCConfig checks that the OIDC provider config map carries the
// required issuer and client-ID entries, reporting the first one missing.
func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
	for _, field := range []string{ConfigFieldIssuer, ConfigFieldClientID} {
		if _, present := config[field]; !present {
			return fmt.Errorf("OIDC provider requires '%s' field", field)
		}
	}
	return nil
}
// validateLDAPConfig checks that the LDAP provider config map carries the
// required server and baseDN entries, reporting the first one missing.
func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error {
	for _, field := range []string{"server", "baseDN"} {
		if _, present := config[field]; !present {
			return fmt.Errorf("LDAP provider requires '%s' field", field)
		}
	}
	return nil
}
// validateSAMLConfig validates SAML provider configuration.
// SAML support does not exist yet, so every config is accepted unchecked.
// TODO: Implement when SAML provider is available
func (f *ProviderFactory) validateSAMLConfig(_ map[string]interface{}) error {
	return nil
}
// GetSupportedProviderTypes returns list of supported provider types.
// Only OIDC is currently wired into the factory.
func (f *ProviderFactory) GetSupportedProviderTypes() []string {
	supported := []string{ProviderTypeOIDC}
	return supported
}

View File

@@ -1,312 +0,0 @@
package sts
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProviderFactory_CreateOIDCProvider verifies that a fully populated,
// enabled OIDC config yields a provider carrying the configured name.
func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
	f := NewProviderFactory()
	cfg := &ProviderConfig{
		Name:    "test-oidc",
		Type:    "oidc",
		Enabled: true,
		Config: map[string]interface{}{
			"issuer":       "https://test-issuer.com",
			"clientId":     "test-client",
			"clientSecret": "test-secret",
			"jwksUri":      "https://test-issuer.com/.well-known/jwks.json",
			"scopes":       []string{"openid", "profile", "email"},
		},
	}
	p, err := f.CreateProvider(cfg)
	require.NoError(t, err)
	assert.NotNil(t, p)
	assert.Equal(t, "test-oidc", p.Name())
}
// Note: Mock provider tests removed - mock providers are now test-only
// and not available through the production ProviderFactory
// TestProviderFactory_DisabledProvider verifies that a config with
// Enabled=false is skipped gracefully: no provider is created and no error
// is returned.
func TestProviderFactory_DisabledProvider(t *testing.T) {
	factory := NewProviderFactory()
	config := &ProviderConfig{
		Name:    "disabled-provider",
		Type:    "oidc",
		Enabled: false,
		Config: map[string]interface{}{
			"issuer":   "https://test-issuer.com",
			"clientId": "test-client",
		},
	}
	provider, err := factory.CreateProvider(config)
	require.NoError(t, err)
	assert.Nil(t, provider) // Should return nil for disabled providers
}
// TestProviderFactory_InvalidProviderType verifies that an unknown provider
// type is rejected with an error that names the unsupported type.
func TestProviderFactory_InvalidProviderType(t *testing.T) {
	factory := NewProviderFactory()
	config := &ProviderConfig{
		Name:    "invalid-provider",
		Type:    "unsupported-type",
		Enabled: true,
		Config:  map[string]interface{}{},
	}
	provider, err := factory.CreateProvider(config)
	assert.Error(t, err)
	assert.Nil(t, provider)
	// Error text is part of the contract checked by other validation tests.
	assert.Contains(t, err.Error(), "unsupported provider type")
}
// TestProviderFactory_LoadMultipleProviders verifies that
// LoadProvidersFromConfig instantiates only the enabled entries and keys the
// resulting map by provider name.
func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
	factory := NewProviderFactory()
	configs := []*ProviderConfig{
		{
			Name:    "oidc-provider",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":   "https://oidc-issuer.com",
				"clientId": "oidc-client",
			},
		},
		{
			Name:    "disabled-provider",
			Type:    "oidc",
			Enabled: false,
			Config: map[string]interface{}{
				"issuer":   "https://disabled-issuer.com",
				"clientId": "disabled-client",
			},
		},
	}
	providers, err := factory.LoadProvidersFromConfig(configs)
	require.NoError(t, err)
	assert.Len(t, providers, 1) // Only enabled providers should be loaded
	assert.Contains(t, providers, "oidc-provider")
	assert.NotContains(t, providers, "disabled-provider")
}
// TestProviderFactory_ValidateOIDCConfig exercises ValidateProviderConfig for
// OIDC configs: a complete config passes, while configs missing the required
// "issuer" or "clientId" fields fail with errors naming the missing field.
func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
	factory := NewProviderFactory()
	t.Run("valid config", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "valid-oidc",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":   "https://valid-issuer.com",
				"clientId": "valid-client",
			},
		}
		err := factory.ValidateProviderConfig(config)
		assert.NoError(t, err)
	})
	t.Run("missing issuer", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "invalid-oidc",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"clientId": "valid-client",
			},
		}
		err := factory.ValidateProviderConfig(config)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "issuer")
	})
	t.Run("missing clientId", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "invalid-oidc",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer": "https://valid-issuer.com",
			},
		}
		err := factory.ValidateProviderConfig(config)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "clientId")
	})
}
// TestProviderFactory_ConvertToStringSlice checks the internal helper that
// normalizes config values into []string: both []string and []interface{}
// inputs succeed, anything else is an error.
func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
	factory := NewProviderFactory()
	t.Run("string slice", func(t *testing.T) {
		input := []string{"a", "b", "c"}
		result, err := factory.convertToStringSlice(input)
		require.NoError(t, err)
		assert.Equal(t, []string{"a", "b", "c"}, result)
	})
	t.Run("interface slice", func(t *testing.T) {
		// JSON-decoded configs arrive as []interface{}, so this path matters.
		input := []interface{}{"a", "b", "c"}
		result, err := factory.convertToStringSlice(input)
		require.NoError(t, err)
		assert.Equal(t, []string{"a", "b", "c"}, result)
	})
	t.Run("invalid type", func(t *testing.T) {
		input := "not-a-slice"
		result, err := factory.convertToStringSlice(input)
		assert.Error(t, err)
		assert.Nil(t, result)
	})
}
// TestProviderFactory_ConfigConversionErrors verifies that CreateProvider
// surfaces a descriptive error when a config value has the wrong shape
// (scopes not an array, claimsMapping/roleMapping not maps).
func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
	factory := NewProviderFactory()
	t.Run("invalid scopes type", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "invalid-scopes",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":   "https://test-issuer.com",
				"clientId": "test-client",
				"scopes":   "invalid-not-array", // Should be array
			},
		}
		provider, err := factory.CreateProvider(config)
		assert.Error(t, err)
		assert.Nil(t, provider)
		assert.Contains(t, err.Error(), "failed to convert scopes")
	})
	t.Run("invalid claimsMapping type", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "invalid-claims",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":        "https://test-issuer.com",
				"clientId":      "test-client",
				"claimsMapping": "invalid-not-map", // Should be map
			},
		}
		provider, err := factory.CreateProvider(config)
		assert.Error(t, err)
		assert.Nil(t, provider)
		assert.Contains(t, err.Error(), "failed to convert claimsMapping")
	})
	t.Run("invalid roleMapping type", func(t *testing.T) {
		config := &ProviderConfig{
			Name:    "invalid-roles",
			Type:    "oidc",
			Enabled: true,
			Config: map[string]interface{}{
				"issuer":      "https://test-issuer.com",
				"clientId":    "test-client",
				"roleMapping": "invalid-not-map", // Should be map
			},
		}
		provider, err := factory.CreateProvider(config)
		assert.Error(t, err)
		assert.Nil(t, provider)
		assert.Contains(t, err.Error(), "failed to convert roleMapping")
	})
}
// TestProviderFactory_ConvertToStringMap checks the internal helper that
// normalizes config values into map[string]string: both map[string]string and
// map[string]interface{} inputs succeed, anything else is an error.
func TestProviderFactory_ConvertToStringMap(t *testing.T) {
	factory := NewProviderFactory()
	t.Run("string map", func(t *testing.T) {
		input := map[string]string{"key1": "value1", "key2": "value2"}
		result, err := factory.convertToStringMap(input)
		require.NoError(t, err)
		assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
	})
	t.Run("interface map", func(t *testing.T) {
		// JSON-decoded configs arrive as map[string]interface{}.
		input := map[string]interface{}{"key1": "value1", "key2": "value2"}
		result, err := factory.convertToStringMap(input)
		require.NoError(t, err)
		assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
	})
	t.Run("invalid type", func(t *testing.T) {
		input := "not-a-map"
		result, err := factory.convertToStringMap(input)
		assert.Error(t, err)
		assert.Nil(t, result)
	})
}
// TestProviderFactory_GetSupportedProviderTypes pins the advertised provider
// types: exactly one entry, "oidc". Update this when new types go production.
func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
	factory := NewProviderFactory()
	supportedTypes := factory.GetSupportedProviderTypes()
	assert.Contains(t, supportedTypes, "oidc")
	assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
}
// TestSTSService_LoadProvidersFromConfig verifies that STSService.Initialize
// builds its providers map from the Providers section of STSConfig, keyed by
// the configured provider name.
func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
	stsConfig := &STSConfig{
		TokenDuration:    FlexibleDuration{3600 * time.Second},
		MaxSessionLength: FlexibleDuration{43200 * time.Second},
		Issuer:           "test-issuer",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
		Providers: []*ProviderConfig{
			{
				Name:    "test-provider",
				Type:    "oidc",
				Enabled: true,
				Config: map[string]interface{}{
					"issuer":   "https://test-issuer.com",
					"clientId": "test-client",
				},
			},
		},
	}
	stsService := NewSTSService()
	err := stsService.Initialize(stsConfig)
	require.NoError(t, err)
	// Check that provider was loaded
	assert.Len(t, stsService.providers, 1)
	assert.Contains(t, stsService.providers, "test-provider")
	assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
}
// TestSTSService_NoProvidersConfig verifies that initialization succeeds even
// when no identity providers are configured; the service simply starts with
// an empty provider map.
func TestSTSService_NoProvidersConfig(t *testing.T) {
	stsConfig := &STSConfig{
		TokenDuration:    FlexibleDuration{3600 * time.Second},
		MaxSessionLength: FlexibleDuration{43200 * time.Second},
		Issuer:           "test-issuer",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
		// No providers configured
	}
	stsService := NewSTSService()
	err := stsService.Initialize(stsConfig)
	require.NoError(t, err)
	// Should initialize successfully with no providers
	assert.Len(t, stsService.providers, 0)
}

View File

@@ -1,193 +0,0 @@
package sts
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
// with specific issuer claims can only be validated by the provider registered for
// that issuer. It registers two providers with distinct issuers and checks that
// tokens route strictly by their "iss" claim, unknown issuers fail, and non-JWT
// opaque tokens are rejected outright.
func TestSecurityIssuerToProviderMapping(t *testing.T) {
	ctx := context.Background()
	// Create STS service with two mock providers
	service := NewSTSService()
	config := &STSConfig{
		TokenDuration:    FlexibleDuration{time.Hour},
		MaxSessionLength: FlexibleDuration{time.Hour * 12},
		Issuer:           "test-sts",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
	}
	err := service.Initialize(config)
	require.NoError(t, err)
	// Set up mock trust policy validator
	mockValidator := &MockTrustPolicyValidator{}
	service.SetTrustPolicyValidator(mockValidator)
	// Create two mock providers with different issuers
	providerA := &MockIdentityProviderWithIssuer{
		name:   "provider-a",
		issuer: "https://provider-a.com",
		validTokens: map[string]bool{
			"token-for-provider-a": true,
		},
	}
	providerB := &MockIdentityProviderWithIssuer{
		name:   "provider-b",
		issuer: "https://provider-b.com",
		validTokens: map[string]bool{
			"token-for-provider-b": true,
		},
	}
	// Register both providers
	err = service.RegisterProvider(providerA)
	require.NoError(t, err)
	err = service.RegisterProvider(providerB)
	require.NoError(t, err)
	// Create JWT tokens with specific issuer claims
	tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
	tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
	t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
		// This should succeed - token has issuer A and provider A is registered
		identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
		assert.NoError(t, err)
		assert.NotNil(t, identity)
		assert.Equal(t, "provider-a", provider.Name())
	})
	t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
		// This should succeed - token has issuer B and provider B is registered
		identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
		assert.NoError(t, err)
		assert.NotNil(t, identity)
		assert.Equal(t, "provider-b", provider.Name())
	})
	t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
		// Create token with unregistered issuer
		tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
		// This should fail - no provider registered for this issuer
		identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
		assert.Error(t, err)
		assert.Nil(t, identity)
		assert.Nil(t, provider)
		assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
	})
	t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
		// Non-JWT tokens should be rejected - no fallback mechanism exists for security
		identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
		assert.Error(t, err)
		assert.Nil(t, identity)
		assert.Nil(t, provider)
		assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
	})
}
// createTestJWT builds and signs a short-lived HS256 JWT carrying the given
// issuer and subject, for driving issuer-to-provider mapping tests. The
// signature key is a fixed test value; these tokens are parsed unverified by
// the mock providers.
func createTestJWT(t *testing.T, issuer, subject string) string {
	now := time.Now()
	claims := jwt.MapClaims{
		"iss": issuer,
		"sub": subject,
		"aud": "test-client",
		"exp": now.Add(time.Hour).Unix(),
		"iat": now.Unix(),
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-signing-key"))
	require.NoError(t, err)
	return signed
}
// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping.
// It is used to verify that the STS service routes JWTs to the provider whose
// issuer matches the token's "iss" claim.
type MockIdentityProviderWithIssuer struct {
	name        string          // provider name reported via Name()
	issuer      string          // issuer this provider claims via GetIssuer()
	validTokens map[string]bool // allow-list for opaque (non-JWT) tokens
}
// Name returns the configured provider name.
func (m *MockIdentityProviderWithIssuer) Name() string {
	return m.name
}

// GetIssuer returns the issuer this mock provider is registered for.
func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
	return m.issuer
}

// Initialize is a no-op; the mock needs no configuration.
func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
	return nil
}
// Authenticate validates a bearer token. JWT-looking tokens are parsed
// (without signature verification — this is a test double) and the embedded
// issuer must match this provider's issuer; short opaque tokens are checked
// against the validTokens allow-list.
func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
	// For JWT tokens, parse and validate the token format
	// (heuristic: longer than 50 chars and contains a dot).
	if len(token) > 50 && strings.Contains(token, ".") {
		// This looks like a JWT - parse it to get the subject
		parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
		if err != nil {
			return nil, fmt.Errorf("invalid JWT token")
		}
		claims, ok := parsedToken.Claims.(jwt.MapClaims)
		if !ok {
			return nil, fmt.Errorf("invalid claims")
		}
		issuer, _ := claims["iss"].(string)
		subject, _ := claims["sub"].(string)
		// Verify the issuer matches what we expect
		if issuer != m.issuer {
			return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
		}
		return &providers.ExternalIdentity{
			UserID:   subject,
			Email:    subject + "@" + m.name + ".com",
			Provider: m.name,
		}, nil
	}
	// For non-JWT tokens, check our simple token list
	if m.validTokens[token] {
		return &providers.ExternalIdentity{
			UserID:   "test-user",
			Email:    "test@" + m.name + ".com",
			Provider: m.name,
		}, nil
	}
	return nil, fmt.Errorf("invalid token")
}
// GetUserInfo fabricates an identity for any user ID; the mock never fails
// lookups.
func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
	return &providers.ExternalIdentity{
		UserID:   userID,
		Email:    userID + "@" + m.name + ".com",
		Provider: m.name,
	}, nil
}

// ValidateToken accepts only tokens on the validTokens allow-list and returns
// fixed claims carrying this provider's issuer.
func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
	if m.validTokens[token] {
		return &providers.TokenClaims{
			Subject: "test-user",
			Issuer:  m.issuer,
		}, nil
	}
	return nil, fmt.Errorf("invalid token")
}

View File

@@ -1,168 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createSessionPolicyTestJWT creates a test JWT token for session policy tests.
// The token is HS256-signed with a fixed test key and expires in one hour.
func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"iss": issuer,
		"sub": subject,
		"aud": "test-client",
		"exp": time.Now().Add(time.Hour).Unix(),
		"iat": time.Now().Unix(),
	})
	tokenString, err := token.SignedString([]byte("test-signing-key"))
	require.NoError(t, err)
	return tokenString
}
// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies
// are preserved in tokens: the normalized policy must round-trip through the
// issued session token, and omitting the policy leaves it empty.
func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()
	sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}`
	testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user")
	request := &AssumeRoleWithWebIdentityRequest{
		RoleArn:          "arn:aws:iam::role/TestRole",
		WebIdentityToken: testToken,
		RoleSessionName:  "test-session",
		Policy:           &sessionPolicy,
	}
	response, err := service.AssumeRoleWithWebIdentity(ctx, request)
	require.NoError(t, err)
	require.NotNil(t, response)
	sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
	require.NoError(t, err)
	// The stored policy is the normalized form, not the raw input string.
	normalized, err := NormalizeSessionPolicy(sessionPolicy)
	require.NoError(t, err)
	assert.Equal(t, normalized, sessionInfo.SessionPolicy)
	t.Run("should_succeed_without_session_policy", func(t *testing.T) {
		request := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/TestRole",
			WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
			RoleSessionName:  "test-session",
		}
		response, err := service.AssumeRoleWithWebIdentity(ctx, request)
		require.NoError(t, err)
		require.NotNil(t, response)
		sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
		require.NoError(t, err)
		assert.Empty(t, sessionInfo.SessionPolicy)
	})
}
// TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases covers edge-case
// handling of the Policy field: malformed JSON and semantically invalid policy
// documents are rejected, while a whitespace-only policy is treated as absent.
func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()
	t.Run("malformed_json_policy_rejected", func(t *testing.T) {
		malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON
		request := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/TestRole",
			WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
			RoleSessionName:  "test-session",
			Policy:           &malformedPolicy,
		}
		response, err := service.AssumeRoleWithWebIdentity(ctx, request)
		assert.Error(t, err)
		assert.Nil(t, response)
		assert.Contains(t, err.Error(), "invalid session policy JSON")
	})
	t.Run("invalid_policy_document_rejected", func(t *testing.T) {
		// Valid JSON but missing Action/Resource in the statement.
		invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}`
		request := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/TestRole",
			WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
			RoleSessionName:  "test-session",
			Policy:           &invalidPolicy,
		}
		response, err := service.AssumeRoleWithWebIdentity(ctx, request)
		assert.Error(t, err)
		assert.Nil(t, response)
		assert.Contains(t, err.Error(), "invalid session policy document")
	})
	t.Run("whitespace_policy_ignored", func(t *testing.T) {
		whitespacePolicy := " \t\n "
		request := &AssumeRoleWithWebIdentityRequest{
			RoleArn:          "arn:aws:iam::role/TestRole",
			WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
			RoleSessionName:  "test-session",
			Policy:           &whitespacePolicy,
		}
		response, err := service.AssumeRoleWithWebIdentity(ctx, request)
		require.NoError(t, err)
		require.NotNil(t, response)
		sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
		require.NoError(t, err)
		assert.Empty(t, sessionInfo.SessionPolicy)
	})
}
// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the
// struct field exists and is optional: Policy is a *string defaulting to nil,
// and assigned values round-trip unchanged.
func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) {
	request := &AssumeRoleWithWebIdentityRequest{}
	assert.IsType(t, (*string)(nil), request.Policy,
		"Policy field should be *string type for optional JSON policy")
	assert.Nil(t, request.Policy,
		"Policy field should default to nil (no session policy)")
	policyValue := `{"Version": "2012-10-17"}`
	request.Policy = &policyValue
	assert.NotNil(t, request.Policy, "Should be able to assign policy value")
	assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved")
}
// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support
// for the credentials-based (username/password) assume-role flow: the
// normalized policy must be embedded in the issued session token.
func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()
	sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}`
	request := &AssumeRoleWithCredentialsRequest{
		RoleArn:         "arn:aws:iam::role/TestRole",
		Username:        "testuser",
		Password:        "testpass",
		RoleSessionName: "test-session",
		ProviderName:    "test-ldap",
		Policy:          &sessionPolicy,
	}
	response, err := service.AssumeRoleWithCredentials(ctx, request)
	require.NoError(t, err)
	require.NotNil(t, response)
	sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
	require.NoError(t, err)
	// Compare against the normalized form, as stored by the service.
	normalized, err := NormalizeSessionPolicy(sessionPolicy)
	require.NoError(t, err)
	assert.Equal(t, normalized, sessionInfo.SessionPolicy)
}

View File

@@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
return duration
}
// extractSessionIdFromToken extracts session ID from JWT session token.
// Returns the empty string when the token is neither a valid JWT nor a bare
// 32-character session ID.
func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
	// Validate JWT and extract session claims
	claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
	if err != nil {
		// For test compatibility, also handle direct session IDs
		if len(sessionToken) == 32 { // Typical session ID length
			return sessionToken
		}
		// Neither a valid JWT nor a plausible raw session ID.
		return ""
	}
	return claims.SessionId
}
// validateAssumeRoleWithCredentialsRequest validates the credentials request parameters
func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error {
	if request.RoleArn == "" {

View File

@@ -1,778 +0,0 @@
package sts
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createSTSTestJWT creates a test JWT token for STS service tests.
// The token is HS256-signed with a fixed test key, carries the given issuer
// and subject, and expires in one hour.
func createSTSTestJWT(t *testing.T, issuer, subject string) string {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"iss": issuer,
		"sub": subject,
		"aud": "test-client",
		"exp": time.Now().Add(time.Hour).Unix(),
		"iat": time.Now().Unix(),
	})
	tokenString, err := token.SignedString([]byte("test-signing-key"))
	require.NoError(t, err)
	return tokenString
}
// TestSTSServiceInitialization tests STS service initialization: a complete
// config succeeds, while a missing signing key or a negative token duration
// is rejected. Defaults are verified for fields left at their zero value.
func TestSTSServiceInitialization(t *testing.T) {
	tests := []struct {
		name    string
		config  *STSConfig
		wantErr bool
	}{
		{
			name: "valid config",
			config: &STSConfig{
				TokenDuration:    FlexibleDuration{time.Hour},
				MaxSessionLength: FlexibleDuration{time.Hour * 12},
				Issuer:           "seaweedfs-sts",
				SigningKey:       []byte("test-signing-key"),
			},
			wantErr: false,
		},
		{
			name: "missing signing key",
			config: &STSConfig{
				TokenDuration: FlexibleDuration{time.Hour},
				Issuer:        "seaweedfs-sts",
			},
			wantErr: true,
		},
		{
			name: "invalid token duration",
			config: &STSConfig{
				TokenDuration: FlexibleDuration{-time.Hour},
				Issuer:        "seaweedfs-sts",
				SigningKey:    []byte("test-key"),
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			service := NewSTSService()
			err := service.Initialize(tt.config)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.True(t, service.IsInitialized())
				// Verify defaults if applicable
				if tt.config.Issuer == "" {
					assert.Equal(t, DefaultIssuer, service.Config.Issuer)
				}
				if tt.config.TokenDuration.Duration == 0 {
					assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration)
				}
			}
		})
	}
}
// TestSTSServiceDefaults verifies that Initialize fills in default issuer,
// token duration, and max session length when the config omits them.
// Note: Initialize mutates the passed-in config struct, which is what the
// assertions below rely on.
func TestSTSServiceDefaults(t *testing.T) {
	service := NewSTSService()
	config := &STSConfig{
		SigningKey: []byte("test-signing-key"),
		// Missing duration and issuer
	}
	err := service.Initialize(config)
	assert.NoError(t, err)
	assert.Equal(t, DefaultIssuer, config.Issuer)
	assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration)
	assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration)
}
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens:
// valid tokens against an existing role produce credentials and an assumed
// role user, while invalid tokens or unknown roles fail. Also covers a
// caller-specified session duration.
func TestAssumeRoleWithWebIdentity(t *testing.T) {
	service := setupTestSTSService(t)
	tests := []struct {
		name             string
		roleArn          string
		webIdentityToken string
		sessionName      string
		durationSeconds  *int64
		wantErr          bool
		expectedSubject  string
	}{
		{
			name:             "successful role assumption",
			roleArn:          "arn:aws:iam::role/TestRole",
			webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
			sessionName:      "test-session",
			durationSeconds:  nil, // Use default
			wantErr:          false,
			expectedSubject:  "test-user-id",
		},
		{
			name:             "invalid web identity token",
			roleArn:          "arn:aws:iam::role/TestRole",
			webIdentityToken: "invalid-token",
			sessionName:      "test-session",
			wantErr:          true,
		},
		{
			name:             "non-existent role",
			roleArn:          "arn:aws:iam::role/NonExistentRole",
			webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
			sessionName:      "test-session",
			wantErr:          true,
		},
		{
			name:             "custom session duration",
			roleArn:          "arn:aws:iam::role/TestRole",
			webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
			sessionName:      "test-session",
			durationSeconds:  int64Ptr(7200), // 2 hours
			wantErr:          false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()
			request := &AssumeRoleWithWebIdentityRequest{
				RoleArn:          tt.roleArn,
				WebIdentityToken: tt.webIdentityToken,
				RoleSessionName:  tt.sessionName,
				DurationSeconds:  tt.durationSeconds,
			}
			response, err := service.AssumeRoleWithWebIdentity(ctx, request)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, response)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, response)
				assert.NotNil(t, response.Credentials)
				assert.NotNil(t, response.AssumedRoleUser)
				// Verify credentials
				creds := response.Credentials
				assert.NotEmpty(t, creds.AccessKeyId)
				assert.NotEmpty(t, creds.SecretAccessKey)
				assert.NotEmpty(t, creds.SessionToken)
				assert.True(t, creds.Expiration.After(time.Now()))
				// Verify assumed role user
				user := response.AssumedRoleUser
				assert.Equal(t, tt.roleArn, user.AssumedRoleId)
				assert.Contains(t, user.Arn, tt.sessionName)
				if tt.expectedSubject != "" {
					assert.Equal(t, tt.expectedSubject, user.Subject)
				}
			}
		})
	}
}
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials against
// the mock "test-ldap" provider: a correct username/password pair succeeds,
// a wrong password fails.
func TestAssumeRoleWithLDAP(t *testing.T) {
	service := setupTestSTSService(t)
	tests := []struct {
		name        string
		roleArn     string
		username    string
		password    string
		sessionName string
		wantErr     bool
	}{
		{
			name:        "successful LDAP role assumption",
			roleArn:     "arn:aws:iam::role/LDAPRole",
			username:    "testuser",
			password:    "testpass",
			sessionName: "ldap-session",
			wantErr:     false,
		},
		{
			name:        "invalid LDAP credentials",
			roleArn:     "arn:aws:iam::role/LDAPRole",
			username:    "testuser",
			password:    "wrongpass",
			sessionName: "ldap-session",
			wantErr:     true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := context.Background()
			request := &AssumeRoleWithCredentialsRequest{
				RoleArn:         tt.roleArn,
				Username:        tt.username,
				Password:        tt.password,
				RoleSessionName: tt.sessionName,
				ProviderName:    "test-ldap",
			}
			response, err := service.AssumeRoleWithCredentials(ctx, request)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, response)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, response)
				assert.NotNil(t, response.Credentials)
			}
		})
	}
}
// TestSessionTokenValidation tests session token validation: a token minted by
// AssumeRoleWithWebIdentity validates and round-trips its session metadata,
// while bogus or empty tokens are rejected.
func TestSessionTokenValidation(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()
	// First, create a session
	request := &AssumeRoleWithWebIdentityRequest{
		RoleArn:          "arn:aws:iam::role/TestRole",
		WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
		RoleSessionName:  "test-session",
	}
	response, err := service.AssumeRoleWithWebIdentity(ctx, request)
	require.NoError(t, err)
	require.NotNil(t, response)
	sessionToken := response.Credentials.SessionToken
	tests := []struct {
		name    string
		token   string
		wantErr bool
	}{
		{
			name:    "valid session token",
			token:   sessionToken,
			wantErr: false,
		},
		{
			name:    "invalid session token",
			token:   "invalid-session-token",
			wantErr: true,
		},
		{
			name:    "empty session token",
			token:   "",
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			session, err := service.ValidateSessionToken(ctx, tt.token)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, session)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, session)
				// Session metadata must match what was requested above.
				assert.Equal(t, "test-session", session.SessionName)
				assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn)
			}
		})
	}
}
// TestSessionTokenPersistence tests that JWT tokens remain valid throughout
// their lifetime. Note: in the stateless JWT design, tokens cannot be revoked
// and remain valid until expiration, so repeated validations must all succeed
// and return a consistent session ID.
func TestSessionTokenPersistence(t *testing.T) {
	service := setupTestSTSService(t)
	ctx := context.Background()
	// Create a session first
	request := &AssumeRoleWithWebIdentityRequest{
		RoleArn:          "arn:aws:iam::role/TestRole",
		WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
		RoleSessionName:  "test-session",
	}
	response, err := service.AssumeRoleWithWebIdentity(ctx, request)
	require.NoError(t, err)
	sessionToken := response.Credentials.SessionToken
	// Verify token is valid initially
	session, err := service.ValidateSessionToken(ctx, sessionToken)
	assert.NoError(t, err)
	assert.NotNil(t, session)
	assert.Equal(t, "test-session", session.SessionName)
	// In a stateless JWT system, tokens remain valid throughout their lifetime
	// Multiple validations should all succeed as long as the token hasn't expired
	session2, err := service.ValidateSessionToken(ctx, sessionToken)
	assert.NoError(t, err, "Token should remain valid in stateless system")
	assert.NotNil(t, session2, "Session should be returned from JWT token")
	assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
}
// Helper functions

// setupTestSTSService builds a fully-initialized STSService for tests:
// fixed signing key and issuer, a mock trust-policy validator, a mock OIDC
// provider ("test-oidc") and a mock LDAP provider ("test-ldap").
func setupTestSTSService(t *testing.T) *STSService {
	service := NewSTSService()
	config := &STSConfig{
		TokenDuration:    FlexibleDuration{time.Hour},
		MaxSessionLength: FlexibleDuration{time.Hour * 12},
		Issuer:           "test-sts",
		SigningKey:       []byte("test-signing-key-32-characters-long"),
	}
	err := service.Initialize(config)
	require.NoError(t, err)
	// Set up mock trust policy validator (required for STS testing)
	mockValidator := &MockTrustPolicyValidator{}
	service.SetTrustPolicyValidator(mockValidator)
	// Register test providers
	// NOTE(review): the validTokens key is a freshly-minted JWT whose iat/exp
	// differ from tokens minted later in the tests; JWT authentication works
	// because MockIdentityProvider.Authenticate parses JWTs directly rather
	// than relying on this map.
	mockOIDCProvider := &MockIdentityProvider{
		name: "test-oidc",
		validTokens: map[string]*providers.TokenClaims{
			createSTSTestJWT(t, "test-issuer", "test-user"): {
				Subject: "test-user-id",
				Issuer:  "test-issuer",
				Claims: map[string]interface{}{
					"email": "test@example.com",
					"name":  "Test User",
				},
			},
		},
	}
	mockLDAPProvider := &MockIdentityProvider{
		name: "test-ldap",
		validCredentials: map[string]string{
			"testuser": "testpass",
		},
	}
	service.RegisterProvider(mockOIDCProvider)
	service.RegisterProvider(mockLDAPProvider)
	return service
}
// int64Ptr returns a pointer to a copy of v, for populating optional int64
// request fields in test fixtures.
func int64Ptr(v int64) *int64 {
	out := v
	return &out
}
// Mock identity provider for testing. It can act as either an OIDC-style
// provider (JWT/opaque tokens via validTokens) or an LDAP-style provider
// (username/password pairs via validCredentials).
type MockIdentityProvider struct {
	name             string                            // provider name reported via Name()
	validTokens      map[string]*providers.TokenClaims // opaque/legacy token allow-list with canned claims
	validCredentials map[string]string                 // username -> password for credential auth
}
// Name returns the configured provider name.
func (m *MockIdentityProvider) Name() string {
	return m.name
}

// GetIssuer returns the fixed test issuer used across these tests.
func (m *MockIdentityProvider) GetIssuer() string {
	return "test-issuer" // This matches the issuer in the token claims
}

// Initialize is a no-op; the mock needs no configuration.
func (m *MockIdentityProvider) Initialize(config interface{}) error {
	return nil
}
// Authenticate validates a token in three tiers: (1) JWT-looking tokens are
// parsed unverified and accepted when issued by "test-issuer" with a non-empty
// subject; (2) legacy opaque tokens are looked up in validTokens; (3)
// "username:password" strings are checked against validCredentials.
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
	// First try to parse as JWT token
	// (heuristic: longer than 20 chars with at least two dots).
	if len(token) > 20 && strings.Count(token, ".") >= 2 {
		parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
		if err == nil {
			if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
				issuer, _ := claims["iss"].(string)
				subject, _ := claims["sub"].(string)
				// Verify the issuer matches what we expect
				if issuer == "test-issuer" && subject != "" {
					return &providers.ExternalIdentity{
						UserID:      subject,
						Email:       subject + "@test-domain.com",
						DisplayName: "Test User " + subject,
						Provider:    m.name,
					}, nil
				}
			}
		}
	}
	// Handle legacy OIDC tokens (for backwards compatibility)
	if claims, exists := m.validTokens[token]; exists {
		email, _ := claims.GetClaimString("email")
		name, _ := claims.GetClaimString("name")
		return &providers.ExternalIdentity{
			UserID:      claims.Subject,
			Email:       email,
			DisplayName: name,
			Provider:    m.name,
		}, nil
	}
	// Handle LDAP credentials (username:password format)
	if m.validCredentials != nil {
		parts := strings.Split(token, ":")
		if len(parts) == 2 {
			username, password := parts[0], parts[1]
			if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
				return &providers.ExternalIdentity{
					UserID:      username,
					Email:       username + "@" + m.name + ".com",
					DisplayName: "Test User " + username,
					Provider:    m.name,
				}, nil
			}
		}
	}
	return nil, fmt.Errorf("unknown test token: %s", token)
}
// GetUserInfo fabricates an identity for any user ID; the mock never fails
// lookups.
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
	return &providers.ExternalIdentity{
		UserID:   userID,
		Email:    userID + "@" + m.name + ".com",
		Provider: m.name,
	}, nil
}

// ValidateToken accepts only tokens present in the validTokens map and returns
// their canned claims; unlike Authenticate, it does not parse JWTs.
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
	if claims, exists := m.validTokens[token]; exists {
		return claims, nil
	}
	return nil, fmt.Errorf("invalid token")
}
// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim.
// It drives calculateSessionDuration directly with combinations of a requested
// duration (nil = use the configured default) and a token expiration (nil = none).
func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Each case records the inputs plus the expected (capped) session length in seconds.
tests := []struct {
name string
durationSeconds *int64
tokenExpiration *time.Time
expectedMaxSeconds int64
description string
}{
{
name: "no token expiration - use default duration",
durationSeconds: nil,
tokenExpiration: nil,
expectedMaxSeconds: 3600, // 1 hour default
description: "When no token expiration is set, use the configured default duration",
},
{
name: "token expires before default duration",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
expectedMaxSeconds: 30 * 60, // 30 minutes
description: "When token expires in 30 min, session should be capped at 30 min",
},
{
name: "token expires after default duration - use default",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
description: "When token expires after default duration, use the default duration",
},
{
name: "requested duration shorter than token expiry",
durationSeconds: int64Ptr(1800), // 30 min requested
tokenExpiration: timePtr(time.Now().Add(time.Hour)),
expectedMaxSeconds: 1800, // 30 minutes as requested
description: "When requested duration is shorter than token expiry, use requested duration",
},
{
name: "requested duration longer than token expiry - cap at token expiry",
durationSeconds: int64Ptr(3600), // 1 hour requested
tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
description: "When requested duration exceeds token expiry, cap at token expiry",
},
{
name: "GitLab CI short-lived token scenario",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
expectedMaxSeconds: 5 * 60, // 5 minutes
description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
},
{
name: "already expired token - defense in depth",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
expectedMaxSeconds: 60, // 1 minute minimum
description: "Already expired token should result in minimal 1 minute session",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
// Allow 5 second tolerance for time calculations, since the expirations
// above are computed relative to time.Now() before the call.
maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
assert.GreaterOrEqual(t, duration, minExpected,
"%s: duration %v should be >= %v", tt.description, duration, minExpected)
assert.LessOrEqual(t, duration, maxExpected,
"%s: duration %v should be <= %v", tt.description, duration, maxExpected)
})
}
}
// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped.
// A provider returning a 10-minute token expiration is registered, and the
// resulting credentials must not outlive that expiration.
func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Set up mock trust policy validator
mockValidator := &MockTrustPolicyValidator{}
service.SetTrustPolicyValidator(mockValidator)
// Create a mock provider that returns tokens with short expiration
shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
mockProvider := &MockIdentityProviderWithExpiration{
name: "short-lived-issuer",
tokenExpiration: &shortLivedTokenExpiration,
}
service.RegisterProvider(mockProvider)
ctx := context.Background()
// Create a JWT token with short expiration; the issuer must match the
// registered provider name so the STS service routes to the mock.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": "short-lived-issuer",
"sub": "test-user",
"aud": "test-client",
"exp": shortLivedTokenExpiration.Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: tokenString,
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
// Verify the session expires at or before the token expiration
// Allow 5 second tolerance
assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
"Session expiration (%v) should not exceed token expiration (%v)",
response.Credentials.Expiration, shortLivedTokenExpiration)
}
// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration.
type MockIdentityProviderWithExpiration struct {
name string // provider name, also used as the issuer
tokenExpiration *time.Time // expiration reported on every authenticated identity; nil means none
}
// Name reports the provider's registered name.
func (m *MockIdentityProviderWithExpiration) Name() string { return m.name }
// GetIssuer reports the issuer, which for this mock equals the provider name.
func (m *MockIdentityProviderWithExpiration) GetIssuer() string { return m.name }
// Initialize is a no-op: this mock accepts any configuration.
func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error { return nil }
// Authenticate parses the token without verifying its signature, extracts
// the subject, and returns an identity that carries the mock's configured
// token expiration.
func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
	parsed, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
	if err != nil {
		return nil, fmt.Errorf("failed to parse token: %w", err)
	}
	mapClaims, ok := parsed.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("invalid claims")
	}
	sub, _ := mapClaims["sub"].(string)
	return &providers.ExternalIdentity{
		UserID:          sub,
		Email:           sub + "@example.com",
		DisplayName:     "Test User",
		Provider:        m.name,
		TokenExpiration: m.tokenExpiration,
	}, nil
}
// GetUserInfo returns a minimal identity for userID; only the user ID and
// provider fields are populated.
func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
	identity := providers.ExternalIdentity{UserID: userID, Provider: m.name}
	return &identity, nil
}
// ValidateToken returns fixed claims for the "test-user" subject, copying
// the mock's configured expiration into ExpiresAt when one is set.
func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
	tc := providers.TokenClaims{
		Subject: "test-user",
		Issuer:  m.name,
	}
	if exp := m.tokenExpiration; exp != nil {
		tc.ExpiresAt = *exp
	}
	return &tc, nil
}
func timePtr(t time.Time) *time.Time {
return &t
}
// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider
// are correctly propagated to the session token's request context.
// It registers a provider whose identities carry custom attributes, assumes a
// role, then decodes the issued session token and checks RequestContext.
func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) {
service := setupTestSTSService(t)
// Create a mock provider that returns a user with attributes
mockProvider := &MockIdentityProviderWithAttributes{
name: "attr-provider",
attributes: map[string]string{
"preferred_username": "my-user",
"department": "engineering",
"project": "seaweedfs",
},
}
service.RegisterProvider(mockProvider)
// Create a valid JWT token for the provider; the "iss" claim must match the
// provider name so the STS service routes authentication to the mock.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": "attr-provider",
"sub": "test-user-id",
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
ctx := context.Background()
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: tokenString,
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
// Validate the session token to check claims
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
// Check that attributes are present in RequestContext
require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil")
assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"])
assert.Equal(t, "engineering", sessionInfo.RequestContext["department"])
assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"])
// Check standard claims are also present
assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"])
assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"])
assert.Equal(t, "Test User", sessionInfo.RequestContext["name"])
}
// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes.
type MockIdentityProviderWithAttributes struct {
name string // provider name, also used as the issuer
attributes map[string]string // attributes attached to every authenticated identity
}
// Name reports the provider's registered name.
func (m *MockIdentityProviderWithAttributes) Name() string { return m.name }
// GetIssuer reports the issuer, which for this mock equals the provider name.
func (m *MockIdentityProviderWithAttributes) GetIssuer() string { return m.name }
// Initialize is a no-op: this mock accepts any configuration.
func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error { return nil }
// Authenticate ignores the token and always returns the same fixed identity,
// decorated with this mock's configured attributes.
func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
	identity := providers.ExternalIdentity{
		UserID:      "test-user-id",
		Email:       "test@example.com",
		DisplayName: "Test User",
		Provider:    m.name,
		Attributes:  m.attributes,
	}
	return &identity, nil
}
// GetUserInfo is not exercised by these tests; it reports no identity and no error.
func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
	var identity *providers.ExternalIdentity
	return identity, nil
}
// ValidateToken returns fixed claims for the test user regardless of the token.
func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
	tc := providers.TokenClaims{
		Subject: "test-user-id",
		Issuer:  m.name,
	}
	return &tc, nil
}

View File

@@ -1,53 +1,4 @@
package sts
import (
"context"
"fmt"
"strings"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// MockTrustPolicyValidator is a simple mock for testing STS functionality
type MockTrustPolicyValidator struct{}
// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing.
// Roles whose ARN mentions "NonExistentRole" are rejected, anything shaped
// like a JWT is allowed, and a small set of legacy literal tokens is
// explicitly allowed or denied.
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error {
	// Reject non-existent roles for testing.
	if strings.Contains(roleArn, "NonExistentRole") {
		return fmt.Errorf("trust policy validation failed: role does not exist")
	}
	switch {
	case len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2:
		// Looks like a JWT (header.payload.signature) - allow it for testing.
		// In a real implementation this would validate actual trust policies.
		return nil
	case webIdentityToken == "valid_test_token", webIdentityToken == "valid-oidc-token":
		// Legacy test tokens kept during migration.
		return nil
	case webIdentityToken == "invalid_token", webIdentityToken == "expired_token", webIdentityToken == "invalid-token":
		return fmt.Errorf("trust policy denies token")
	}
	return nil
}
// ValidateTrustPolicyForCredentials allows valid test identities for STS testing.
// Only roles containing "NonExistentRole" and identities without a user ID
// are rejected.
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error {
	// Reject non-existent roles for testing.
	if strings.Contains(roleArn, "NonExistentRole") {
		return fmt.Errorf("trust policy validation failed: role does not exist")
	}
	if identity == nil || identity.UserID == "" {
		return fmt.Errorf("invalid identity for role assumption")
	}
	// For STS unit tests, any identity with a user ID may assume the role.
	return nil
}

View File

@@ -1,29 +0,0 @@
package images
import (
"bytes"
"io"
"path/filepath"
"strings"
)
/*
* Preprocess image files on client side.
* 1. possibly adjust the orientation
* 2. resize the image to a width or height limit
* 3. remove the exif data
* Call this function on any file uploaded to SeaweedFS
*
*/
func MaybePreprocessImage(filename string, data []byte, width, height int) (resized io.ReadSeeker, w int, h int) {
ext := filepath.Ext(filename)
ext = strings.ToLower(ext)
switch ext {
case ".png", ".gif":
return Resized(ext, bytes.NewReader(data), width, height, "")
case ".jpg", ".jpeg":
data = FixJpgOrientation(data)
return Resized(ext, bytes.NewReader(data), width, height, "")
}
return bytes.NewReader(data), 0, 0
}

View File

@@ -290,15 +290,6 @@ func (loader *ConfigLoader) ValidateConfiguration() error {
return nil
}
// LoadKMSFromFilerToml is a convenience function to load KMS configuration from filer.toml.
// It builds a loader around v, loads all configurations, and then validates them.
func LoadKMSFromFilerToml(v ViperConfig) error {
	loader := NewConfigLoader(v)
	if loadErr := loader.LoadConfigurations(); loadErr != nil {
		return loadErr
	}
	return loader.ValidateConfiguration()
}
// LoadKMSFromConfig loads KMS configuration directly from parsed JSON data // LoadKMSFromConfig loads KMS configuration directly from parsed JSON data
func LoadKMSFromConfig(kmsConfig interface{}) error { func LoadKMSFromConfig(kmsConfig interface{}) error {
kmsMap, ok := kmsConfig.(map[string]interface{}) kmsMap, ok := kmsConfig.(map[string]interface{})
@@ -415,12 +406,3 @@ func getIntFromConfig(config map[string]interface{}, key string, defaultValue in
} }
return defaultValue return defaultValue
} }
// getStringFromConfig reads key from config as a string, falling back to
// defaultValue when the key is absent or holds a non-string value.
func getStringFromConfig(config map[string]interface{}, key string, defaultValue string) string {
	// A missing key yields a nil interface, so the assertion fails in both
	// the absent and wrong-type cases.
	if s, ok := config[key].(string); ok {
		return s
	}
	return defaultValue
}

View File

@@ -147,13 +147,6 @@ func (fh *FileHandle) ReleaseHandle() {
} }
} }
// lessThan orders chunks by modification time, breaking ties by file key.
func lessThan(a, b *filer_pb.FileChunk) bool {
	if a.ModifiedTsNs != b.ModifiedTsNs {
		return a.ModifiedTsNs < b.ModifiedTsNs
	}
	return a.Fid.FileKey < b.Fid.FileKey
}
// getCumulativeOffsets returns cached cumulative offsets for chunks, computing them if necessary // getCumulativeOffsets returns cached cumulative offsets for chunks, computing them if necessary
func (fh *FileHandle) getCumulativeOffsets(chunks []*filer_pb.FileChunk) []int64 { func (fh *FileHandle) getCumulativeOffsets(chunks []*filer_pb.FileChunk) []int64 {
fh.chunkCacheLock.RLock() fh.chunkCacheLock.RLock()

View File

@@ -21,9 +21,3 @@ func min(x, y int64) int64 {
} }
return y return y
} }
// minInt returns the smaller of two ints.
func minInt(x, y int) int {
	if y < x {
		return y
	}
	return x
}

View File

@@ -119,13 +119,6 @@ func (c *RDMAMountClient) lookupVolumeLocationByFileID(ctx context.Context, file
return bestAddress, nil return bestAddress, nil
} }
// lookupVolumeLocation finds the best volume server for a given volume ID (legacy method)
// NOTE(review): the file ID is formatted here as "volumeId,needleHex,cookieDecimal";
// SeaweedFS file IDs typically encode needle+cookie as a single hex token —
// confirm the sidecar accepts this comma-separated form before reusing this helper.
func (c *RDMAMountClient) lookupVolumeLocation(ctx context.Context, volumeID uint32, needleID uint64, cookie uint32) (string, error) {
// Create a file ID for lookup (format: volumeId,needleId,cookie)
fileID := fmt.Sprintf("%d,%x,%d", volumeID, needleID, cookie)
return c.lookupVolumeLocationByFileID(ctx, fileID)
}
// healthCheck verifies that the RDMA sidecar is available and functioning // healthCheck verifies that the RDMA sidecar is available and functioning
func (c *RDMAMountClient) healthCheck() error { func (c *RDMAMountClient) healthCheck() error {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout) ctx, cancel := context.WithTimeout(context.Background(), c.timeout)

View File

@@ -117,11 +117,6 @@ func GetBrokerErrorInfo(code int32) BrokerErrorInfo {
} }
} }
// GetKafkaErrorCode returns the corresponding Kafka protocol error code for a broker error.
func GetKafkaErrorCode(brokerErrorCode int32) int16 {
	info := GetBrokerErrorInfo(brokerErrorCode)
	return info.KafkaCode
}
// CreateBrokerError creates a structured broker error with both error code and message // CreateBrokerError creates a structured broker error with both error code and message
func CreateBrokerError(code int32, message string) (int32, string) { func CreateBrokerError(code int32, message string) (int32, string) {
info := GetBrokerErrorInfo(code) info := GetBrokerErrorInfo(code)

View File

@@ -1,351 +0,0 @@
package broker
import (
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// createTestTopic builds the fixed topic used throughout the offset-manager tests.
func createTestTopic() topic.Topic {
	testTopic := topic.Topic{
		Namespace: "test",
		Name:      "offset-test",
	}
	return testTopic
}
// createTestPartition builds a partition covering ring slots 0-31, stamped
// with the current time.
func createTestPartition() topic.Partition {
	now := time.Now().UnixNano()
	return topic.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: now,
	}
}
func TestBrokerOffsetManager_AssignOffset(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Test sequential offset assignment
for i := int64(0); i < 10; i++ {
assignedOffset, err := manager.AssignOffset(testTopic, testPartition)
if err != nil {
t.Fatalf("Failed to assign offset %d: %v", i, err)
}
if assignedOffset != i {
t.Errorf("Expected offset %d, got %d", i, assignedOffset)
}
}
}
// TestBrokerOffsetManager_AssignBatchOffsets verifies that batch assignment
// returns contiguous [base, last] ranges and that successive batches continue
// where the previous one ended.
func TestBrokerOffsetManager_AssignBatchOffsets(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Assign batch of offsets
baseOffset, lastOffset, err := manager.AssignBatchOffsets(testTopic, testPartition, 5)
if err != nil {
t.Fatalf("Failed to assign batch offsets: %v", err)
}
if baseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", baseOffset)
}
if lastOffset != 4 {
t.Errorf("Expected last offset 4, got %d", lastOffset)
}
// Assign another batch; it must start right after the first one.
baseOffset2, lastOffset2, err := manager.AssignBatchOffsets(testTopic, testPartition, 3)
if err != nil {
t.Fatalf("Failed to assign second batch offsets: %v", err)
}
if baseOffset2 != 5 {
t.Errorf("Expected base offset 5, got %d", baseOffset2)
}
if lastOffset2 != 7 {
t.Errorf("Expected last offset 7, got %d", lastOffset2)
}
}
// TestBrokerOffsetManager_GetHighWaterMark verifies that the high water mark
// starts at zero and advances to one past the last assigned offset.
func TestBrokerOffsetManager_GetHighWaterMark(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Initially should be 0
hwm, err := manager.GetHighWaterMark(testTopic, testPartition)
if err != nil {
t.Fatalf("Failed to get initial high water mark: %v", err)
}
if hwm != 0 {
t.Errorf("Expected initial high water mark 0, got %d", hwm)
}
// Assign some offsets
manager.AssignBatchOffsets(testTopic, testPartition, 10)
// High water mark should be updated (offsets 0-9 assigned -> HWM 10)
hwm, err = manager.GetHighWaterMark(testTopic, testPartition)
if err != nil {
t.Fatalf("Failed to get high water mark after assignment: %v", err)
}
if hwm != 10 {
t.Errorf("Expected high water mark 10, got %d", hwm)
}
}
// TestBrokerOffsetManager_CreateSubscription verifies that a subscription
// anchored at RESET_TO_EARLIEST starts at offset 0 and keeps its ID.
func TestBrokerOffsetManager_CreateSubscription(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Assign some offsets first
manager.AssignBatchOffsets(testTopic, testPartition, 5)
// Create subscription
sub, err := manager.CreateSubscription(
"test-sub",
testTopic,
testPartition,
schema_pb.OffsetType_RESET_TO_EARLIEST,
0,
)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
if sub.ID != "test-sub" {
t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
}
if sub.StartOffset != 0 {
t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
}
}
// TestBrokerOffsetManager_GetPartitionOffsetInfo verifies the reported
// earliest/latest offsets and high water mark, both for an empty partition
// (latest = -1) and after a batch assignment.
func TestBrokerOffsetManager_GetPartitionOffsetInfo(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Test empty partition
info, err := manager.GetPartitionOffsetInfo(testTopic, testPartition)
if err != nil {
t.Fatalf("Failed to get partition offset info: %v", err)
}
if info.EarliestOffset != 0 {
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
}
if info.LatestOffset != -1 {
t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
}
// Assign offsets and test again
manager.AssignBatchOffsets(testTopic, testPartition, 5)
info, err = manager.GetPartitionOffsetInfo(testTopic, testPartition)
if err != nil {
t.Fatalf("Failed to get partition offset info after assignment: %v", err)
}
if info.LatestOffset != 4 {
t.Errorf("Expected latest offset 4, got %d", info.LatestOffset)
}
if info.HighWaterMark != 5 {
t.Errorf("Expected high water mark 5, got %d", info.HighWaterMark)
}
}
// TestBrokerOffsetManager_MultiplePartitions verifies that offset counters
// are tracked independently per partition: assignments in one partition must
// not advance another partition's counter.
func TestBrokerOffsetManager_MultiplePartitions(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
// Create different partitions (disjoint ring ranges)
partition1 := topic.Partition{
RingSize: 1024,
RangeStart: 0,
RangeStop: 31,
UnixTimeNs: time.Now().UnixNano(),
}
partition2 := topic.Partition{
RingSize: 1024,
RangeStart: 32,
RangeStop: 63,
UnixTimeNs: time.Now().UnixNano(),
}
// Assign offsets to different partitions
assignedOffset1, err := manager.AssignOffset(testTopic, partition1)
if err != nil {
t.Fatalf("Failed to assign offset to partition1: %v", err)
}
assignedOffset2, err := manager.AssignOffset(testTopic, partition2)
if err != nil {
t.Fatalf("Failed to assign offset to partition2: %v", err)
}
// Both should start at 0
if assignedOffset1 != 0 {
t.Errorf("Expected offset 0 for partition1, got %d", assignedOffset1)
}
if assignedOffset2 != 0 {
t.Errorf("Expected offset 0 for partition2, got %d", assignedOffset2)
}
// Assign more offsets to partition1
assignedOffset1_2, err := manager.AssignOffset(testTopic, partition1)
if err != nil {
t.Fatalf("Failed to assign second offset to partition1: %v", err)
}
if assignedOffset1_2 != 1 {
t.Errorf("Expected offset 1 for partition1, got %d", assignedOffset1_2)
}
// Partition2 should still be at 0 for next assignment
assignedOffset2_2, err := manager.AssignOffset(testTopic, partition2)
if err != nil {
t.Fatalf("Failed to assign second offset to partition2: %v", err)
}
if assignedOffset2_2 != 1 {
t.Errorf("Expected offset 1 for partition2, got %d", assignedOffset2_2)
}
}
// TestOffsetAwarePublisher checks the basic wiring of an offset-aware
// publisher: the partition passed at construction must be returned by
// GetPartition. Actual message publishing is not exercised here.
func TestOffsetAwarePublisher(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Create a mock local partition (simplified for testing)
localPartition := &topic.LocalPartition{}
// Create offset assignment function
assignOffsetFn := func() (int64, error) {
return manager.AssignOffset(testTopic, testPartition)
}
// Create offset-aware publisher
publisher := topic.NewOffsetAwarePublisher(localPartition, assignOffsetFn)
if publisher.GetPartition() != localPartition {
t.Error("Publisher should return the correct partition")
}
// Test would require more setup to actually publish messages
// This tests the basic structure
}
// TestBrokerOffsetManager_GetOffsetMetrics verifies that metrics start empty
// and reflect partition activity after assignments and subscriptions.
func TestBrokerOffsetManager_GetOffsetMetrics(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Initial metrics
metrics := manager.GetOffsetMetrics()
if metrics.TotalOffsets != 0 {
t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
}
// Assign some offsets
manager.AssignBatchOffsets(testTopic, testPartition, 5)
// Create subscription
manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
// Check updated metrics
metrics = manager.GetOffsetMetrics()
if metrics.PartitionCount != 1 {
t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
}
}
// TestBrokerOffsetManager_AssignOffsetsWithResult verifies the structured
// result form of batch assignment: range, count, topic/partition echo, and
// a populated timestamp.
func TestBrokerOffsetManager_AssignOffsetsWithResult(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Assign offsets with result
result := manager.AssignOffsetsWithResult(testTopic, testPartition, 3)
if result.Error != nil {
t.Fatalf("Expected no error, got: %v", result.Error)
}
if result.BaseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", result.BaseOffset)
}
if result.LastOffset != 2 {
t.Errorf("Expected last offset 2, got %d", result.LastOffset)
}
if result.Count != 3 {
t.Errorf("Expected count 3, got %d", result.Count)
}
if result.Topic != testTopic {
t.Error("Topic mismatch in result")
}
if result.Partition != testPartition {
t.Error("Partition mismatch in result")
}
if result.Timestamp <= 0 {
t.Error("Timestamp should be set")
}
}
// TestBrokerOffsetManager_Shutdown verifies that Shutdown does not panic and
// that the manager remains usable afterwards, restarting offsets from zero
// because fresh internal managers are created.
func TestBrokerOffsetManager_Shutdown(t *testing.T) {
storage := NewInMemoryOffsetStorageForTesting()
manager := NewBrokerOffsetManagerWithStorage(storage)
testTopic := createTestTopic()
testPartition := createTestPartition()
// Assign some offsets and create subscriptions
manager.AssignBatchOffsets(testTopic, testPartition, 5)
manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
// Shutdown should not panic
manager.Shutdown()
// After shutdown, operations should still work (using new managers)
offset, err := manager.AssignOffset(testTopic, testPartition)
if err != nil {
t.Fatalf("Operations should still work after shutdown: %v", err)
}
// Should start from 0 again (new manager)
if offset != 0 {
t.Errorf("Expected offset 0 after shutdown, got %d", offset)
}
}

View File

@@ -203,14 +203,6 @@ func (b *MessageQueueBroker) GetDataCenter() string {
} }
// withMasterClient runs fn against the given master's gRPC API, dialing with
// the broker's configured gRPC options (non-leader-follow mode).
func (b *MessageQueueBroker) withMasterClient(streamingMode bool, master pb.ServerAddress, fn func(client master_pb.SeaweedClient) error) error {
	// Pass fn straight through instead of wrapping it in an identity closure;
	// the extra closure added an allocation and no behavior.
	return pb.WithMasterClient(streamingMode, master, b.grpcDialOption, false, fn)
}
func (b *MessageQueueBroker) withBrokerClient(streamingMode bool, server pb.ServerAddress, fn func(client mq_pb.SeaweedMessagingClient) error) error { func (b *MessageQueueBroker) withBrokerClient(streamingMode bool, server pb.ServerAddress, fn func(client mq_pb.SeaweedMessagingClient) error) error {
return pb.WithBrokerGrpcClient(streamingMode, server.String(), b.grpcDialOption, func(client mq_pb.SeaweedMessagingClient) error { return pb.WithBrokerGrpcClient(streamingMode, server.String(), b.grpcDialOption, func(client mq_pb.SeaweedMessagingClient) error {

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"strings"
"time" "time"
"github.com/seaweedfs/seaweedfs/weed/filer_client" "github.com/seaweedfs/seaweedfs/weed/filer_client"
@@ -192,10 +191,6 @@ func (f *FilerStorage) getOffsetPath(group, topic string, partition int32) strin
return fmt.Sprintf("%s/offset", f.getPartitionPath(group, topic, partition)) return fmt.Sprintf("%s/offset", f.getPartitionPath(group, topic, partition))
} }
// getMetadataPath returns the filer path of the metadata file for the given
// consumer group / topic / partition.
func (f *FilerStorage) getMetadataPath(group, topic string, partition int32) string {
	base := f.getPartitionPath(group, topic, partition)
	return base + "/metadata"
}
func (f *FilerStorage) writeFile(path string, data []byte) error { func (f *FilerStorage) writeFile(path string, data []byte) error {
fullPath := util.FullPath(path) fullPath := util.FullPath(path)
dir, name := fullPath.DirAndName() dir, name := fullPath.DirAndName()
@@ -311,16 +306,3 @@ func (f *FilerStorage) deleteDirectory(path string) error {
return err return err
}) })
} }
// normalizePath removes leading/trailing slashes and collapses multiple slashes,
// always producing a path with a single leading "/" and no trailing slash.
func normalizePath(path string) string {
	var kept []string
	for _, segment := range strings.Split(path, "/") {
		// Empty segments come from leading, trailing, or repeated slashes.
		if segment != "" {
			kept = append(kept, segment)
		}
	}
	return "/" + strings.Join(kept, "/")
}

View File

@@ -1,65 +0,0 @@
package consumer_offset
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Note: These tests require a running filer instance
// They are marked as integration tests and should be run with:
// go test -tags=integration
// TestFilerStorageCommitAndFetch is a placeholder integration test: it is
// skipped because it needs a running filer.
func TestFilerStorageCommitAndFetch(t *testing.T) {
t.Skip("Requires running filer - integration test")
// This will be implemented once we have test infrastructure
// Test will:
// 1. Create filer storage
// 2. Commit offset
// 3. Fetch offset
// 4. Verify values match
}
// TestFilerStoragePersistence is a placeholder integration test: it is
// skipped because it needs a running filer.
func TestFilerStoragePersistence(t *testing.T) {
t.Skip("Requires running filer - integration test")
// Test will:
// 1. Commit offset with first storage instance
// 2. Close first instance
// 3. Create new storage instance
// 4. Fetch offset and verify it persisted
}
// TestFilerStorageMultipleGroups is a placeholder integration test: it is
// skipped because it needs a running filer.
func TestFilerStorageMultipleGroups(t *testing.T) {
t.Skip("Requires running filer - integration test")
// Test will:
// 1. Commit offsets for multiple groups
// 2. Fetch all offsets per group
// 3. Verify isolation between groups
}
// TestFilerStoragePath checks the filer path layout for consumer offsets:
// base/group/topic/partition plus the offset and metadata leaf files. Pure
// string construction, so no running filer is needed.
func TestFilerStoragePath(t *testing.T) {
// Test path generation (doesn't require filer)
storage := &FilerStorage{}
group := "test-group"
topic := "test-topic"
partition := int32(5)
groupPath := storage.getGroupPath(group)
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group", groupPath)
topicPath := storage.getTopicPath(group, topic)
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic", topicPath)
partitionPath := storage.getPartitionPath(group, topic, partition)
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5", partitionPath)
offsetPath := storage.getOffsetPath(group, topic, partition)
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/offset", offsetPath)
metadataPath := storage.getMetadataPath(group, topic, partition)
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/metadata", metadataPath)
}

View File

@@ -278,38 +278,3 @@ func (h *SeaweedMQHandler) checkTopicInFiler(topicName string) bool {
return exists return exists
} }
// listTopicsFromFiler lists all topics from the filer.
//
// Topics are represented as directories under /topics/kafka. Listing is
// best-effort: on any failure the function returns whatever (possibly empty)
// list it has collected so far rather than propagating an error.
func (h *SeaweedMQHandler) listTopicsFromFiler() []string {
	if h.filerClientAccessor == nil {
		return []string{}
	}
	var topics []string
	h.filerClientAccessor.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
		request := &filer_pb.ListEntriesRequest{
			Directory: "/topics/kafka",
		}
		stream, err := client.ListEntries(context.Background(), request)
		if err != nil {
			return nil // Don't propagate error, just return empty list
		}
		for {
			resp, err := stream.Recv()
			if err != nil {
				break // End of stream or error
			}
			// Only directories represent topics; plain files are ignored.
			// (The original code had an empty else-if branch here; removed.)
			if resp.Entry != nil && resp.Entry.IsDirectory {
				topics = append(topics, resp.Entry.Name)
			}
		}
		return nil
	})
	return topics
}

View File

@@ -1,53 +0,0 @@
package kafka
import (
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// Convenience functions for partition mapping used by production code
// The full PartitionMapper implementation is in partition_mapping_test.go for testing
// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range.
//
// Each Kafka partition owns a contiguous slice of 35 ring slots; 35 divides
// MaxPartitionCount (2520) evenly, yielding exactly 72 Kafka partitions.
func MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
	const rangeSize int32 = 35
	rangeStart = kafkaPartition * rangeSize
	rangeStop = rangeStart + (rangeSize - 1)
	return rangeStart, rangeStop
}
// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition,
// using the shared partition-to-ring-range mapping.
func CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
	start, stop := MapKafkaPartitionToSMQRange(kafkaPartition)
	partition := schema_pb.Partition{
		RingSize:   pub_balancer.MaxPartitionCount,
		RangeStart: start,
		RangeStop:  stop,
		UnixTimeNs: unixTimeNs,
	}
	return &partition
}
// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from a
// SeaweedMQ range start (the inverse of MapKafkaPartitionToSMQRange).
func ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
	const rangeSize int32 = 35
	return rangeStart / rangeSize
}
// ValidateKafkaPartition validates that a Kafka partition is within the
// supported range [0, MaxPartitionCount/35).
func ValidateKafkaPartition(kafkaPartition int32) bool {
	limit := int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
	return 0 <= kafkaPartition && kafkaPartition < limit
}
// GetRangeSize returns the number of ring slots assigned to one Kafka partition.
func GetRangeSize() int32 {
	const slotsPerPartition int32 = 35
	return slotsPerPartition
}
// GetMaxKafkaPartitions returns how many Kafka partitions fit on the ring:
// MaxPartitionCount / 35 = 72.
func GetMaxKafkaPartitions() int32 {
	const slotsPerPartition int32 = 35
	return int32(pub_balancer.MaxPartitionCount) / slotsPerPartition
}

View File

@@ -1,294 +0,0 @@
package kafka
import (
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// PartitionMapper provides a consistent Kafka-partition to SeaweedMQ-ring
// mapping for the test suite.
// NOTE: This is test-only code and not used in the actual Kafka Gateway implementation
type PartitionMapper struct{}

// NewPartitionMapper creates a new partition mapper.
func NewPartitionMapper() *PartitionMapper {
	return new(PartitionMapper)
}

// GetRangeSize returns the consistent range size for Kafka partition mapping.
// All components must use this single source of truth for the calculation.
func (pm *PartitionMapper) GetRangeSize() int32 {
	// 35 divides MaxPartitionCount (2520) evenly: 2520 / 35 = 72 Kafka
	// partitions, balancing partition granularity against ring utilization.
	return 35
}

// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported.
func (pm *PartitionMapper) GetMaxKafkaPartitions() int32 {
	// With range size 35: 2520 / 35 = 72 Kafka partitions.
	return int32(pub_balancer.MaxPartitionCount) / pm.GetRangeSize()
}

// MapKafkaPartitionToSMQRange maps a Kafka partition to its inclusive
// SeaweedMQ ring range.
func (pm *PartitionMapper) MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
	size := pm.GetRangeSize()
	rangeStart = kafkaPartition * size
	rangeStop = rangeStart + size - 1
	return rangeStart, rangeStop
}

// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition.
func (pm *PartitionMapper) CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
	start, stop := pm.MapKafkaPartitionToSMQRange(kafkaPartition)
	return &schema_pb.Partition{
		RingSize:   pub_balancer.MaxPartitionCount,
		RangeStart: start,
		RangeStop:  stop,
		UnixTimeNs: unixTimeNs,
	}
}

// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition index from a
// SeaweedMQ range start (inverse of MapKafkaPartitionToSMQRange).
func (pm *PartitionMapper) ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
	return rangeStart / pm.GetRangeSize()
}

// ValidateKafkaPartition reports whether kafkaPartition is within the
// supported range of the ring.
func (pm *PartitionMapper) ValidateKafkaPartition(kafkaPartition int32) bool {
	return kafkaPartition >= 0 && kafkaPartition < pm.GetMaxKafkaPartitions()
}

// GetPartitionMappingInfo returns debug information about the partition mapping.
func (pm *PartitionMapper) GetPartitionMappingInfo() map[string]interface{} {
	maxParts := pm.GetMaxKafkaPartitions()
	size := pm.GetRangeSize()
	return map[string]interface{}{
		"ring_size":            pub_balancer.MaxPartitionCount,
		"range_size":           size,
		"max_kafka_partitions": maxParts,
		"ring_utilization":     float64(maxParts*size) / float64(pub_balancer.MaxPartitionCount),
	}
}

// DefaultPartitionMapper is the global instance shared across the test codebase.
var DefaultPartitionMapper = NewPartitionMapper()
// TestPartitionMapper_GetRangeSize checks the configured range size and that
// it never oversubscribes the ring.
func TestPartitionMapper_GetRangeSize(t *testing.T) {
	pm := NewPartitionMapper()
	size := pm.GetRangeSize()
	if size != 35 {
		t.Errorf("Expected range size 35, got %d", size)
	}
	// The full complement of Kafka partitions must fit inside the ring.
	maxParts := pm.GetMaxKafkaPartitions()
	used := maxParts * size
	if used > int32(pub_balancer.MaxPartitionCount) {
		t.Errorf("Total used slots (%d) exceeds MaxPartitionCount (%d)", used, pub_balancer.MaxPartitionCount)
	}
	t.Logf("Range size: %d, Max Kafka partitions: %d, Ring utilization: %.2f%%",
		size, maxParts, float64(used)/float64(pub_balancer.MaxPartitionCount)*100)
}
// TestPartitionMapper_MapKafkaPartitionToSMQRange verifies known partition to
// range mappings and that every mapped range spans exactly one range size.
func TestPartitionMapper_MapKafkaPartitionToSMQRange(t *testing.T) {
	pm := NewPartitionMapper()
	cases := []struct {
		partition int32
		wantStart int32
		wantStop  int32
	}{
		{partition: 0, wantStart: 0, wantStop: 34},
		{partition: 1, wantStart: 35, wantStop: 69},
		{partition: 2, wantStart: 70, wantStop: 104},
		{partition: 10, wantStart: 350, wantStop: 384},
	}
	for _, tc := range cases {
		t.Run("", func(t *testing.T) {
			gotStart, gotStop := pm.MapKafkaPartitionToSMQRange(tc.partition)
			if gotStart != tc.wantStart {
				t.Errorf("Kafka partition %d: expected start %d, got %d", tc.partition, tc.wantStart, gotStart)
			}
			if gotStop != tc.wantStop {
				t.Errorf("Kafka partition %d: expected stop %d, got %d", tc.partition, tc.wantStop, gotStop)
			}
			// Every mapped range must cover exactly GetRangeSize() slots.
			if width := gotStop - gotStart + 1; width != pm.GetRangeSize() {
				t.Errorf("Inconsistent range size: expected %d, got %d", pm.GetRangeSize(), width)
			}
		})
	}
}
// TestPartitionMapper_ExtractKafkaPartitionFromSMQRange verifies known range
// start to Kafka partition conversions.
func TestPartitionMapper_ExtractKafkaPartitionFromSMQRange(t *testing.T) {
	pm := NewPartitionMapper()
	cases := []struct {
		start int32
		want  int32
	}{
		{start: 0, want: 0},
		{start: 35, want: 1},
		{start: 70, want: 2},
		{start: 350, want: 10},
	}
	for _, tc := range cases {
		t.Run("", func(t *testing.T) {
			if got := pm.ExtractKafkaPartitionFromSMQRange(tc.start); got != tc.want {
				t.Errorf("Range start %d: expected Kafka partition %d, got %d",
					tc.start, tc.want, got)
			}
		})
	}
}
// TestPartitionMapper_RoundTrip checks that every valid Kafka partition
// survives Kafka -> SMQ -> Kafka conversion and that no two adjacent
// partitions share ring slots.
func TestPartitionMapper_RoundTrip(t *testing.T) {
	pm := NewPartitionMapper()
	total := pm.GetMaxKafkaPartitions()
	for p := int32(0); p < total; p++ {
		start, stop := pm.MapKafkaPartitionToSMQRange(p)
		if back := pm.ExtractKafkaPartitionFromSMQRange(start); back != p {
			t.Errorf("Round-trip failed for partition %d: got %d", p, back)
		}
		// Adjacent partitions must not overlap on the ring.
		if p+1 < total {
			nextStart, _ := pm.MapKafkaPartitionToSMQRange(p + 1)
			if stop >= nextStart {
				t.Errorf("Partition %d range [%d,%d] overlaps with partition %d start %d",
					p, start, stop, p+1, nextStart)
			}
		}
	}
}
// TestPartitionMapper_CreateSMQPartition verifies the constructed partition
// descriptor carries the full ring size, the mapped range, and the timestamp.
func TestPartitionMapper_CreateSMQPartition(t *testing.T) {
	pm := NewPartitionMapper()
	kafkaPartition := int32(5)
	now := time.Now().UnixNano()
	p := pm.CreateSMQPartition(kafkaPartition, now)
	if p.RingSize != pub_balancer.MaxPartitionCount {
		t.Errorf("Expected ring size %d, got %d", pub_balancer.MaxPartitionCount, p.RingSize)
	}
	wantStart, wantStop := pm.MapKafkaPartitionToSMQRange(kafkaPartition)
	if p.RangeStart != wantStart {
		t.Errorf("Expected range start %d, got %d", wantStart, p.RangeStart)
	}
	if p.RangeStop != wantStop {
		t.Errorf("Expected range stop %d, got %d", wantStop, p.RangeStop)
	}
	if p.UnixTimeNs != now {
		t.Errorf("Expected timestamp %d, got %d", now, p.UnixTimeNs)
	}
}
// TestPartitionMapper_ValidateKafkaPartition probes the boundaries of the
// valid Kafka partition range.
func TestPartitionMapper_ValidateKafkaPartition(t *testing.T) {
	pm := NewPartitionMapper()
	limit := pm.GetMaxKafkaPartitions()
	cases := []struct {
		partition int32
		want      bool
	}{
		{partition: -1, want: false},
		{partition: 0, want: true},
		{partition: 1, want: true},
		{partition: limit - 1, want: true},
		{partition: limit, want: false},
		{partition: 1000, want: false},
	}
	for _, tc := range cases {
		t.Run("", func(t *testing.T) {
			if got := pm.ValidateKafkaPartition(tc.partition); got != tc.want {
				t.Errorf("Partition %d: expected valid=%v, got %v", tc.partition, tc.want, got)
			}
		})
	}
}
// TestPartitionMapper_ConsistencyWithGlobalFunctions asserts that the
// package-level helpers agree with the equivalent PartitionMapper methods.
func TestPartitionMapper_ConsistencyWithGlobalFunctions(t *testing.T) {
	pm := NewPartitionMapper()
	kafkaPartition := int32(7)
	now := time.Now().UnixNano()
	mStart, mStop := pm.MapKafkaPartitionToSMQRange(kafkaPartition)
	gStart, gStop := MapKafkaPartitionToSMQRange(kafkaPartition)
	if mStart != gStart || mStop != gStop {
		t.Errorf("Global function inconsistent: mapper=(%d,%d), global=(%d,%d)",
			mStart, mStop, gStart, gStop)
	}
	fromMapper := pm.CreateSMQPartition(kafkaPartition, now)
	fromGlobal := CreateSMQPartition(kafkaPartition, now)
	if fromMapper.RangeStart != fromGlobal.RangeStart || fromMapper.RangeStop != fromGlobal.RangeStop {
		t.Errorf("Global CreateSMQPartition inconsistent")
	}
	if a, b := pm.ExtractKafkaPartitionFromSMQRange(mStart), ExtractKafkaPartitionFromSMQRange(mStart); a != b {
		t.Errorf("Global ExtractKafkaPartitionFromSMQRange inconsistent: %d vs %d", a, b)
	}
}
// TestPartitionMapper_GetPartitionMappingInfo checks that the debug info map
// contains all documented keys with sane values.
func TestPartitionMapper_GetPartitionMappingInfo(t *testing.T) {
	pm := NewPartitionMapper()
	info := pm.GetPartitionMappingInfo()
	// All documented keys must be present.
	for _, key := range []string{"ring_size", "range_size", "max_kafka_partitions", "ring_utilization"} {
		if _, ok := info[key]; !ok {
			t.Errorf("Missing key in mapping info: %s", key)
		}
	}
	// Spot-check the reported values.
	if info["ring_size"].(int) != pub_balancer.MaxPartitionCount {
		t.Errorf("Incorrect ring_size in info")
	}
	if info["range_size"].(int32) != pm.GetRangeSize() {
		t.Errorf("Incorrect range_size in info")
	}
	if u := info["ring_utilization"].(float64); u <= 0 || u > 1 {
		t.Errorf("Invalid ring utilization: %f", u)
	}
	t.Logf("Partition mapping info: %+v", info)
}

View File

@@ -2,11 +2,9 @@ package offset
import ( import (
"fmt" "fmt"
"os"
"testing" "testing"
"time" "time"
_ "github.com/mattn/go-sqlite3"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
) )
@@ -62,151 +60,6 @@ func BenchmarkBatchOffsetAssignment(b *testing.B) {
} }
} }
// BenchmarkSQLOffsetStorage benchmarks SQL storage operations
//
// Exercises the SQLite-backed offset store in isolation: checkpoint
// save/load, offset-mapping writes, and several read paths over a table
// pre-populated with 1000 mappings. Error returns inside the timed loops are
// deliberately ignored so the measurement reflects storage cost only.
func BenchmarkSQLOffsetStorage(b *testing.B) {
	// Create temporary database
	tmpFile, err := os.CreateTemp("", "benchmark_*.db")
	if err != nil {
		b.Fatalf("Failed to create temp database: %v", err)
	}
	tmpFile.Close()
	defer os.Remove(tmpFile.Name())
	db, err := CreateDatabase(tmpFile.Name())
	if err != nil {
		b.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		b.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	// NOTE(review): this local shadows the package-level partitionKey helper
	// of the same name; legal Go, but easy to misread.
	partitionKey := partitionKey(partition)
	b.Run("SaveCheckpoint", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.SaveCheckpoint("test-namespace", "test-topic", partition, int64(i))
		}
	})
	b.Run("LoadCheckpoint", func(b *testing.B) {
		// Seed one checkpoint so every timed iteration reads the same row.
		storage.SaveCheckpoint("test-namespace", "test-topic", partition, 1000)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.LoadCheckpoint("test-namespace", "test-topic", partition)
		}
	})
	b.Run("SaveOffsetMapping", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
		}
	})
	// Pre-populate for read benchmarks
	for i := 0; i < 1000; i++ {
		storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
	}
	b.Run("GetHighestOffset", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.GetHighestOffset("test-namespace", "test-topic", partition)
		}
	})
	b.Run("LoadOffsetMappings", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.LoadOffsetMappings(partitionKey)
		}
	})
	b.Run("GetOffsetMappingsByRange", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			// Slide a 100-offset window over the pre-populated rows.
			start := int64(i % 900)
			end := start + 100
			storage.GetOffsetMappingsByRange(partitionKey, start, end)
		}
	})
	b.Run("GetPartitionStats", func(b *testing.B) {
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			storage.GetPartitionStats(partitionKey)
		}
	})
}
// BenchmarkInMemoryVsSQL compares in-memory and SQL storage performance
//
// Runs the same AssignOffset workload against both storage backends so their
// per-offset cost can be compared directly (e.g. with benchstat).
func BenchmarkInMemoryVsSQL(b *testing.B) {
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	// In-memory storage benchmark
	b.Run("InMemory", func(b *testing.B) {
		storage := NewInMemoryOffsetStorage()
		manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
		if err != nil {
			b.Fatalf("Failed to create partition manager: %v", err)
		}
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			manager.AssignOffset()
		}
	})
	// SQL storage benchmark
	b.Run("SQL", func(b *testing.B) {
		// All SQLite setup happens before ResetTimer so only AssignOffset is timed.
		tmpFile, err := os.CreateTemp("", "benchmark_sql_*.db")
		if err != nil {
			b.Fatalf("Failed to create temp database: %v", err)
		}
		tmpFile.Close()
		defer os.Remove(tmpFile.Name())
		db, err := CreateDatabase(tmpFile.Name())
		if err != nil {
			b.Fatalf("Failed to create database: %v", err)
		}
		defer db.Close()
		storage, err := NewSQLOffsetStorage(db)
		if err != nil {
			b.Fatalf("Failed to create SQL storage: %v", err)
		}
		defer storage.Close()
		manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
		if err != nil {
			b.Fatalf("Failed to create partition manager: %v", err)
		}
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			manager.AssignOffset()
		}
	})
}
// BenchmarkOffsetSubscription benchmarks subscription operations // BenchmarkOffsetSubscription benchmarks subscription operations
func BenchmarkOffsetSubscription(b *testing.B) { func BenchmarkOffsetSubscription(b *testing.B) {
storage := NewInMemoryOffsetStorage() storage := NewInMemoryOffsetStorage()

View File

@@ -1,473 +0,0 @@
package offset
import (
"fmt"
"os"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// TestEndToEndOffsetFlow tests the complete offset management flow
//
// NOTE(review): the subtests share one integration instance and rely on
// running in declaration order — the records published in the first subtest
// are what the later subtests read back, seek over, and lag-check.
func TestEndToEndOffsetFlow(t *testing.T) {
	// Create temporary database
	tmpFile, err := os.CreateTemp("", "e2e_offset_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp database: %v", err)
	}
	tmpFile.Close()
	defer os.Remove(tmpFile.Name())
	// Create database with migrations
	db, err := CreateDatabase(tmpFile.Name())
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	// Create SQL storage
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	// Create SMQ offset integration
	integration := NewSMQOffsetIntegration(storage)
	// Test partition
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	t.Run("PublishAndAssignOffsets", func(t *testing.T) {
		// Simulate publishing messages with offset assignment
		records := []PublishRecordRequest{
			{Key: []byte("user1"), Value: &schema_pb.RecordValue{}},
			{Key: []byte("user2"), Value: &schema_pb.RecordValue{}},
			{Key: []byte("user3"), Value: &schema_pb.RecordValue{}},
		}
		response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
		if err != nil {
			t.Fatalf("Failed to publish record batch: %v", err)
		}
		// A fresh partition assigns offsets starting at 0.
		if response.BaseOffset != 0 {
			t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
		}
		if response.LastOffset != 2 {
			t.Errorf("Expected last offset 2, got %d", response.LastOffset)
		}
		// Verify high water mark
		hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
		if err != nil {
			t.Fatalf("Failed to get high water mark: %v", err)
		}
		if hwm != 3 {
			t.Errorf("Expected high water mark 3, got %d", hwm)
		}
	})
	t.Run("CreateAndUseSubscription", func(t *testing.T) {
		// Create subscription from earliest
		sub, err := integration.CreateSubscription(
			"e2e-test-sub",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_RESET_TO_EARLIEST,
			0,
		)
		if err != nil {
			t.Fatalf("Failed to create subscription: %v", err)
		}
		// Subscribe to records
		responses, err := integration.SubscribeRecords(sub, 2)
		if err != nil {
			t.Fatalf("Failed to subscribe to records: %v", err)
		}
		if len(responses) != 2 {
			t.Errorf("Expected 2 responses, got %d", len(responses))
		}
		// Check subscription advancement
		if sub.CurrentOffset != 2 {
			t.Errorf("Expected current offset 2, got %d", sub.CurrentOffset)
		}
		// Get subscription lag
		lag, err := sub.GetLag()
		if err != nil {
			t.Fatalf("Failed to get lag: %v", err)
		}
		if lag != 1 { // 3 (hwm) - 2 (current) = 1
			t.Errorf("Expected lag 1, got %d", lag)
		}
	})
	t.Run("OffsetSeekingAndRanges", func(t *testing.T) {
		// Create subscription at specific offset
		sub, err := integration.CreateSubscription(
			"seek-test-sub",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_EXACT_OFFSET,
			1,
		)
		if err != nil {
			t.Fatalf("Failed to create subscription at offset 1: %v", err)
		}
		// Verify starting position
		if sub.CurrentOffset != 1 {
			t.Errorf("Expected current offset 1, got %d", sub.CurrentOffset)
		}
		// Get offset range
		offsetRange, err := sub.GetOffsetRange(2)
		if err != nil {
			t.Fatalf("Failed to get offset range: %v", err)
		}
		if offsetRange.StartOffset != 1 {
			t.Errorf("Expected start offset 1, got %d", offsetRange.StartOffset)
		}
		if offsetRange.Count != 2 {
			t.Errorf("Expected count 2, got %d", offsetRange.Count)
		}
		// Seek to different offset
		err = sub.SeekToOffset(0)
		if err != nil {
			t.Fatalf("Failed to seek to offset 0: %v", err)
		}
		if sub.CurrentOffset != 0 {
			t.Errorf("Expected current offset 0 after seek, got %d", sub.CurrentOffset)
		}
	})
	t.Run("PartitionInformationAndMetrics", func(t *testing.T) {
		// Get partition offset info
		info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
		if err != nil {
			t.Fatalf("Failed to get partition offset info: %v", err)
		}
		if info.EarliestOffset != 0 {
			t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
		}
		if info.LatestOffset != 2 {
			t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
		}
		if info.HighWaterMark != 3 {
			t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
		}
		if info.ActiveSubscriptions != 2 { // Two subscriptions created above
			t.Errorf("Expected 2 active subscriptions, got %d", info.ActiveSubscriptions)
		}
		// Get offset metrics
		metrics := integration.GetOffsetMetrics()
		if metrics.PartitionCount != 1 {
			t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
		}
		if metrics.ActiveSubscriptions != 2 {
			t.Errorf("Expected 2 active subscriptions in metrics, got %d", metrics.ActiveSubscriptions)
		}
	})
}
// TestOffsetPersistenceAcrossRestarts tests that offsets persist across system restarts
//
// Two braced blocks simulate two process lifetimes against the same database
// file: the first publishes and fully closes (integration before storage
// before db), the second reopens and verifies offset continuity.
func TestOffsetPersistenceAcrossRestarts(t *testing.T) {
	// Create temporary database
	tmpFile, err := os.CreateTemp("", "persistence_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp database: %v", err)
	}
	tmpFile.Close()
	defer os.Remove(tmpFile.Name())
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	var lastOffset int64
	// First session: Create database and assign offsets
	{
		db, err := CreateDatabase(tmpFile.Name())
		if err != nil {
			t.Fatalf("Failed to create database: %v", err)
		}
		storage, err := NewSQLOffsetStorage(db)
		if err != nil {
			t.Fatalf("Failed to create SQL storage: %v", err)
		}
		integration := NewSMQOffsetIntegration(storage)
		// Publish some records
		records := []PublishRecordRequest{
			{Key: []byte("msg1"), Value: &schema_pb.RecordValue{}},
			{Key: []byte("msg2"), Value: &schema_pb.RecordValue{}},
			{Key: []byte("msg3"), Value: &schema_pb.RecordValue{}},
		}
		response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
		if err != nil {
			t.Fatalf("Failed to publish records: %v", err)
		}
		lastOffset = response.LastOffset
		// Close connections - Close integration first to trigger final checkpoint
		integration.Close()
		storage.Close()
		db.Close()
	}
	// Second session: Reopen database and verify persistence
	{
		db, err := CreateDatabase(tmpFile.Name())
		if err != nil {
			t.Fatalf("Failed to reopen database: %v", err)
		}
		defer db.Close()
		storage, err := NewSQLOffsetStorage(db)
		if err != nil {
			t.Fatalf("Failed to create SQL storage: %v", err)
		}
		defer storage.Close()
		integration := NewSMQOffsetIntegration(storage)
		// Verify high water mark persisted
		hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
		if err != nil {
			t.Fatalf("Failed to get high water mark after restart: %v", err)
		}
		if hwm != lastOffset+1 {
			t.Errorf("Expected high water mark %d after restart, got %d", lastOffset+1, hwm)
		}
		// Assign new offsets and verify continuity
		newResponse, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte("msg4"), &schema_pb.RecordValue{})
		if err != nil {
			t.Fatalf("Failed to publish new record after restart: %v", err)
		}
		expectedNextOffset := lastOffset + 1
		if newResponse.BaseOffset != expectedNextOffset {
			t.Errorf("Expected next offset %d after restart, got %d", expectedNextOffset, newResponse.BaseOffset)
		}
	}
}
// TestConcurrentOffsetOperations tests concurrent offset operations
//
// Five goroutines publish ten records each against one integration instance,
// then the test asserts the high water mark and record count account for
// every record exactly once (no lost or duplicated offsets).
func TestConcurrentOffsetOperations(t *testing.T) {
	// Create temporary database
	tmpFile, err := os.CreateTemp("", "concurrent_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp database: %v", err)
	}
	tmpFile.Close()
	defer os.Remove(tmpFile.Name())
	db, err := CreateDatabase(tmpFile.Name())
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	integration := NewSMQOffsetIntegration(storage)
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	// Concurrent publishers
	const numPublishers = 5
	const recordsPerPublisher = 10
	// Buffered so publisher goroutines never block on completion signaling.
	done := make(chan bool, numPublishers)
	for i := 0; i < numPublishers; i++ {
		go func(publisherID int) {
			defer func() { done <- true }()
			for j := 0; j < recordsPerPublisher; j++ {
				key := fmt.Sprintf("publisher-%d-msg-%d", publisherID, j)
				// t.Errorf is documented as safe for concurrent use.
				_, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte(key), &schema_pb.RecordValue{})
				if err != nil {
					t.Errorf("Publisher %d failed to publish message %d: %v", publisherID, j, err)
					return
				}
			}
		}(i)
	}
	// Wait for all publishers to complete
	for i := 0; i < numPublishers; i++ {
		<-done
	}
	// Verify total records
	hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
	if err != nil {
		t.Fatalf("Failed to get high water mark: %v", err)
	}
	expectedTotal := int64(numPublishers * recordsPerPublisher)
	if hwm != expectedTotal {
		t.Errorf("Expected high water mark %d, got %d", expectedTotal, hwm)
	}
	// Verify no duplicate offsets
	info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
	if err != nil {
		t.Fatalf("Failed to get partition info: %v", err)
	}
	if info.RecordCount != expectedTotal {
		t.Errorf("Expected record count %d, got %d", expectedTotal, info.RecordCount)
	}
}
// TestOffsetValidationAndErrorHandling tests error conditions and validation
//
// Each subtest exercises one rejection path: subscribing beyond the high
// water mark, negative offsets, duplicate subscription IDs, and invalid
// offset ranges. Subtests share the integration instance, so ordering matters
// for the data published in OffsetRangeValidation.
func TestOffsetValidationAndErrorHandling(t *testing.T) {
	// Create temporary database
	tmpFile, err := os.CreateTemp("", "validation_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp database: %v", err)
	}
	tmpFile.Close()
	defer os.Remove(tmpFile.Name())
	db, err := CreateDatabase(tmpFile.Name())
	if err != nil {
		t.Fatalf("Failed to create database: %v", err)
	}
	defer db.Close()
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	integration := NewSMQOffsetIntegration(storage)
	partition := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	t.Run("InvalidOffsetSubscription", func(t *testing.T) {
		// Try to create subscription with invalid offset
		_, err := integration.CreateSubscription(
			"invalid-sub",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_EXACT_OFFSET,
			100, // Beyond any existing data
		)
		if err == nil {
			t.Error("Expected error for subscription beyond high water mark")
		}
	})
	t.Run("NegativeOffsetValidation", func(t *testing.T) {
		// Try to create subscription with negative offset
		_, err := integration.CreateSubscription(
			"negative-sub",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_EXACT_OFFSET,
			-1,
		)
		if err == nil {
			t.Error("Expected error for negative offset")
		}
	})
	t.Run("DuplicateSubscriptionID", func(t *testing.T) {
		// Create first subscription
		_, err := integration.CreateSubscription(
			"duplicate-id",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_RESET_TO_EARLIEST,
			0,
		)
		if err != nil {
			t.Fatalf("Failed to create first subscription: %v", err)
		}
		// Try to create duplicate
		_, err = integration.CreateSubscription(
			"duplicate-id",
			"test-namespace", "test-topic",
			partition,
			schema_pb.OffsetType_RESET_TO_EARLIEST,
			0,
		)
		if err == nil {
			t.Error("Expected error for duplicate subscription ID")
		}
	})
	t.Run("OffsetRangeValidation", func(t *testing.T) {
		// Add some data first
		integration.PublishRecord("test-namespace", "test-topic", partition, []byte("test"), &schema_pb.RecordValue{})
		// Test invalid range validation
		err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10) // Beyond high water mark
		if err == nil {
			t.Error("Expected error for range beyond high water mark")
		}
		err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 10, 5) // End before start
		if err == nil {
			t.Error("Expected error for end offset before start offset")
		}
		err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, -1, 5) // Negative start
		if err == nil {
			t.Error("Expected error for negative start offset")
		}
	})
}

View File

@@ -93,9 +93,3 @@ func (f *FilerOffsetStorage) getPartitionDir(namespace, topicName string, partit
return fmt.Sprintf("%s/%s/%s/%s/%s", filer.TopicsDir, namespace, topicName, version, partitionRange) return fmt.Sprintf("%s/%s/%s/%s/%s", filer.TopicsDir, namespace, topicName, version, partitionRange)
} }
// getPartitionKey derives a unique identity string for a partition from its
// ring size, inclusive slot range, and creation timestamp.
func (f *FilerOffsetStorage) getPartitionKey(partition *schema_pb.Partition) string {
	return fmt.Sprintf(
		"ring:%d:range:%d-%d:time:%d",
		partition.RingSize,
		partition.RangeStart,
		partition.RangeStop,
		partition.UnixTimeNs,
	)
}

View File

@@ -1,544 +0,0 @@
package offset
import (
"testing"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
func TestSMQOffsetIntegration_PublishRecord(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Publish a single record
response, err := integration.PublishRecord(
"test-namespace", "test-topic",
partition,
[]byte("test-key"),
&schema_pb.RecordValue{},
)
if err != nil {
t.Fatalf("Failed to publish record: %v", err)
}
if response.Error != "" {
t.Errorf("Expected no error, got: %s", response.Error)
}
if response.BaseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
}
if response.LastOffset != 0 {
t.Errorf("Expected last offset 0, got %d", response.LastOffset)
}
}
// TestSMQOffsetIntegration_PublishRecordBatch verifies that a batch publish on
// a fresh partition assigns the contiguous offsets 0..2 and advances the high
// water mark to the batch size.
func TestSMQOffsetIntegration_PublishRecordBatch(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	integration := NewSMQOffsetIntegration(storage)
	partition := createTestPartition()
	// Create batch of records
	records := []PublishRecordRequest{
		{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
	}
	// Publish batch
	response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
	if err != nil {
		t.Fatalf("Failed to publish record batch: %v", err)
	}
	if response.Error != "" {
		t.Errorf("Expected no error, got: %s", response.Error)
	}
	if response.BaseOffset != 0 {
		t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
	}
	if response.LastOffset != 2 {
		t.Errorf("Expected last offset 2, got %d", response.LastOffset)
	}
	// Verify high water mark
	hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
	if err != nil {
		t.Fatalf("Failed to get high water mark: %v", err)
	}
	if hwm != 3 {
		t.Errorf("Expected high water mark 3, got %d", hwm)
	}
}
func TestSMQOffsetIntegration_EmptyBatch(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Publish empty batch
response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, []PublishRecordRequest{})
if err != nil {
t.Fatalf("Failed to publish empty batch: %v", err)
}
if response.Error == "" {
t.Error("Expected error for empty batch")
}
}
// TestSMQOffsetIntegration_CreateSubscription verifies that a subscription
// created with RESET_TO_EARLIEST on a populated partition carries the
// requested ID and starts at offset 0.
func TestSMQOffsetIntegration_CreateSubscription(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	integration := NewSMQOffsetIntegration(storage)
	partition := createTestPartition()
	// Publish some records first
	records := []PublishRecordRequest{
		{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
	}
	integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
	// Create subscription
	sub, err := integration.CreateSubscription(
		"test-sub",
		"test-namespace", "test-topic",
		partition,
		schema_pb.OffsetType_RESET_TO_EARLIEST,
		0,
	)
	if err != nil {
		t.Fatalf("Failed to create subscription: %v", err)
	}
	if sub.ID != "test-sub" {
		t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
	}
	if sub.StartOffset != 0 {
		t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
	}
}
// TestSMQOffsetIntegration_SubscribeRecords verifies that fetching from a
// subscription returns records in offset order and advances the
// subscription's current offset past the last delivered record.
func TestSMQOffsetIntegration_SubscribeRecords(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	integration := NewSMQOffsetIntegration(storage)
	partition := createTestPartition()
	// Publish some records
	records := []PublishRecordRequest{
		{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
	}
	integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
	// Create subscription
	sub, err := integration.CreateSubscription(
		"test-sub",
		"test-namespace", "test-topic",
		partition,
		schema_pb.OffsetType_RESET_TO_EARLIEST,
		0,
	)
	if err != nil {
		t.Fatalf("Failed to create subscription: %v", err)
	}
	// Subscribe to records
	responses, err := integration.SubscribeRecords(sub, 2)
	if err != nil {
		t.Fatalf("Failed to subscribe to records: %v", err)
	}
	if len(responses) != 2 {
		t.Errorf("Expected 2 responses, got %d", len(responses))
	}
	// Check offset progression
	if responses[0].Offset != 0 {
		t.Errorf("Expected first record offset 0, got %d", responses[0].Offset)
	}
	if responses[1].Offset != 1 {
		t.Errorf("Expected second record offset 1, got %d", responses[1].Offset)
	}
	// Check subscription advancement
	if sub.CurrentOffset != 2 {
		t.Errorf("Expected subscription current offset 2, got %d", sub.CurrentOffset)
	}
}
func TestSMQOffsetIntegration_SubscribeEmptyPartition(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Create subscription on empty partition
sub, err := integration.CreateSubscription(
"empty-sub",
"test-namespace", "test-topic",
partition,
schema_pb.OffsetType_RESET_TO_EARLIEST,
0,
)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Subscribe to records (should return empty)
responses, err := integration.SubscribeRecords(sub, 10)
if err != nil {
t.Fatalf("Failed to subscribe to empty partition: %v", err)
}
if len(responses) != 0 {
t.Errorf("Expected 0 responses from empty partition, got %d", len(responses))
}
}
// TestSMQOffsetIntegration_SeekSubscription verifies that seeking a
// subscription repositions its current offset and that subsequent fetches
// deliver records starting from the seek target.
func TestSMQOffsetIntegration_SeekSubscription(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	integration := NewSMQOffsetIntegration(storage)
	partition := createTestPartition()
	// Publish records
	records := []PublishRecordRequest{
		{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key4"), Value: &schema_pb.RecordValue{}},
		{Key: []byte("key5"), Value: &schema_pb.RecordValue{}},
	}
	integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
	// Create subscription
	sub, err := integration.CreateSubscription(
		"seek-sub",
		"test-namespace", "test-topic",
		partition,
		schema_pb.OffsetType_RESET_TO_EARLIEST,
		0,
	)
	if err != nil {
		t.Fatalf("Failed to create subscription: %v", err)
	}
	// Seek to offset 3
	err = integration.SeekSubscription("seek-sub", 3)
	if err != nil {
		t.Fatalf("Failed to seek subscription: %v", err)
	}
	// The seek by ID must be visible through the handle returned earlier.
	if sub.CurrentOffset != 3 {
		t.Errorf("Expected current offset 3 after seek, got %d", sub.CurrentOffset)
	}
	// Subscribe from new position
	responses, err := integration.SubscribeRecords(sub, 2)
	if err != nil {
		t.Fatalf("Failed to subscribe after seek: %v", err)
	}
	if len(responses) != 2 {
		t.Errorf("Expected 2 responses after seek, got %d", len(responses))
	}
	if responses[0].Offset != 3 {
		t.Errorf("Expected first record offset 3 after seek, got %d", responses[0].Offset)
	}
}
// TestSMQOffsetIntegration_GetSubscriptionLag verifies that lag is computed
// as high water mark minus the subscription's current offset, and that it
// shrinks as the subscription consumes records.
func TestSMQOffsetIntegration_GetSubscriptionLag(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Publish records
records := []PublishRecordRequest{
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
}
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
// Create subscription at offset 1
sub, err := integration.CreateSubscription(
"lag-sub",
"test-namespace", "test-topic",
partition,
schema_pb.OffsetType_EXACT_OFFSET,
1,
)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Get lag
lag, err := integration.GetSubscriptionLag("lag-sub")
if err != nil {
t.Fatalf("Failed to get subscription lag: %v", err)
}
expectedLag := int64(3 - 1) // hwm - current
if lag != expectedLag {
t.Errorf("Expected lag %d, got %d", expectedLag, lag)
}
// Advance subscription and check lag again.
// Return values are intentionally ignored: only the side effect of
// advancing the subscription's current offset matters here.
integration.SubscribeRecords(sub, 1)
lag, err = integration.GetSubscriptionLag("lag-sub")
if err != nil {
t.Fatalf("Failed to get lag after advance: %v", err)
}
expectedLag = int64(3 - 2) // hwm - current
if lag != expectedLag {
t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
}
}
// TestSMQOffsetIntegration_CloseSubscription verifies that a closed
// subscription is fully removed: subsequent lag queries for its ID fail.
func TestSMQOffsetIntegration_CloseSubscription(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Create subscription
_, err := integration.CreateSubscription(
"close-sub",
"test-namespace", "test-topic",
partition,
schema_pb.OffsetType_RESET_TO_EARLIEST,
0,
)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Close subscription
err = integration.CloseSubscription("close-sub")
if err != nil {
t.Fatalf("Failed to close subscription: %v", err)
}
// Try to get lag (should fail, since the subscription no longer exists)
_, err = integration.GetSubscriptionLag("close-sub")
if err == nil {
t.Error("Expected error when getting lag for closed subscription")
}
}
// TestSMQOffsetIntegration_ValidateOffsetRange verifies that a range within
// the published data is accepted while a range extending past the high water
// mark is rejected.
func TestSMQOffsetIntegration_ValidateOffsetRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Publish some records (offsets 0..2, hwm becomes 3)
records := []PublishRecordRequest{
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
}
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
// Test valid range
err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 2)
if err != nil {
t.Errorf("Valid range should not return error: %v", err)
}
// Test invalid range (beyond hwm)
err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 5)
if err == nil {
t.Error("Expected error for range beyond high water mark")
}
}
// TestSMQOffsetIntegration_GetAvailableOffsetRange verifies the reported
// offset range for an empty partition (count 0) and after publishing
// (inclusive start/end offsets plus record count).
func TestSMQOffsetIntegration_GetAvailableOffsetRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Test empty partition
offsetRange, err := integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get available range for empty partition: %v", err)
}
if offsetRange.Count != 0 {
t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
}
// Publish records
records := []PublishRecordRequest{
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
}
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
// Test with data: EndOffset is the last assigned offset (inclusive).
offsetRange, err = integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get available range: %v", err)
}
if offsetRange.StartOffset != 0 {
t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
}
if offsetRange.EndOffset != 1 {
t.Errorf("Expected end offset 1, got %d", offsetRange.EndOffset)
}
if offsetRange.Count != 2 {
t.Errorf("Expected count 2, got %d", offsetRange.Count)
}
}
// TestSMQOffsetIntegration_GetOffsetMetrics verifies aggregate metrics:
// total assigned offsets, active subscription count, and partition count,
// both before and after publishing/subscribing.
func TestSMQOffsetIntegration_GetOffsetMetrics(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Initial metrics
metrics := integration.GetOffsetMetrics()
if metrics.TotalOffsets != 0 {
t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
}
if metrics.ActiveSubscriptions != 0 {
t.Errorf("Expected 0 active subscriptions initially, got %d", metrics.ActiveSubscriptions)
}
// Publish records
records := []PublishRecordRequest{
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
}
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
// Create subscriptions (errors ignored; metric assertions below would
// surface a failed creation as a wrong subscription count)
integration.CreateSubscription("sub1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
integration.CreateSubscription("sub2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
// Check updated metrics
metrics = integration.GetOffsetMetrics()
if metrics.TotalOffsets != 2 {
t.Errorf("Expected 2 total offsets, got %d", metrics.TotalOffsets)
}
if metrics.ActiveSubscriptions != 2 {
t.Errorf("Expected 2 active subscriptions, got %d", metrics.ActiveSubscriptions)
}
if metrics.PartitionCount != 1 {
t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
}
}
// TestSMQOffsetIntegration_GetOffsetInfo verifies per-offset existence
// queries: a missing offset reports Exists=false (not an error), and a
// published offset reports Exists=true with the matching offset value.
func TestSMQOffsetIntegration_GetOffsetInfo(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Test non-existent offset
info, err := integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
if err != nil {
t.Fatalf("Failed to get offset info: %v", err)
}
if info.Exists {
t.Error("Offset should not exist in empty partition")
}
// Publish record (return value ignored; existence is checked below)
integration.PublishRecord("test-namespace", "test-topic", partition, []byte("key1"), &schema_pb.RecordValue{})
// Test existing offset
info, err = integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
if err != nil {
t.Fatalf("Failed to get offset info for existing offset: %v", err)
}
if !info.Exists {
t.Error("Offset should exist after publishing")
}
if info.Offset != 0 {
t.Errorf("Expected offset 0, got %d", info.Offset)
}
}
// TestSMQOffsetIntegration_GetPartitionOffsetInfo verifies the per-partition
// summary (earliest/latest offset, high water mark, record count, active
// subscription count) for both an empty and a populated partition.
func TestSMQOffsetIntegration_GetPartitionOffsetInfo(t *testing.T) {
storage := NewInMemoryOffsetStorage()
integration := NewSMQOffsetIntegration(storage)
partition := createTestPartition()
// Test empty partition: latest offset is -1 (none assigned), hwm is 0
info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get partition offset info: %v", err)
}
if info.EarliestOffset != 0 {
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
}
if info.LatestOffset != -1 {
t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
}
if info.HighWaterMark != 0 {
t.Errorf("Expected high water mark 0, got %d", info.HighWaterMark)
}
if info.RecordCount != 0 {
t.Errorf("Expected record count 0, got %d", info.RecordCount)
}
// Publish records
records := []PublishRecordRequest{
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
}
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
// Create subscription (error ignored; counted via ActiveSubscriptions below)
integration.CreateSubscription("test-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
// Test with data: latest is last assigned offset, hwm is next offset
info, err = integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get partition offset info with data: %v", err)
}
if info.EarliestOffset != 0 {
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
}
if info.LatestOffset != 2 {
t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
}
if info.HighWaterMark != 3 {
t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
}
if info.RecordCount != 3 {
t.Errorf("Expected record count 3, got %d", info.RecordCount)
}
if info.ActiveSubscriptions != 1 {
t.Errorf("Expected 1 active subscription, got %d", info.ActiveSubscriptions)
}
}

View File

@@ -338,13 +338,6 @@ type OffsetAssigner struct {
registry *PartitionOffsetRegistry
}
// NewOffsetAssigner creates a new offset assigner backed by the given
// storage, wrapping a fresh partition offset registry.
func NewOffsetAssigner(storage OffsetStorage) *OffsetAssigner {
	a := &OffsetAssigner{registry: NewPartitionOffsetRegistry(storage)}
	return a
}
// AssignSingleOffset assigns a single offset with timestamp
func (a *OffsetAssigner) AssignSingleOffset(namespace, topicName string, partition *schema_pb.Partition) *AssignmentResult {
offset, err := a.registry.AssignOffset(namespace, topicName, partition)

View File

@@ -1,388 +0,0 @@
package offset
import (
	"sync"
	"testing"
	"time"

	"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// createTestPartition builds the fixed partition used throughout these
// tests: ring slots 0-31 of a 1024-slot ring, stamped with the current time.
func createTestPartition() *schema_pb.Partition {
	p := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 0,
		RangeStop:  31,
		UnixTimeNs: time.Now().UnixNano(),
	}
	return p
}
// TestPartitionOffsetManager_BasicAssignment verifies that single-offset
// assignment is sequential starting at 0 and that the high water mark is the
// next offset to be assigned.
func TestPartitionOffsetManager_BasicAssignment(t *testing.T) {
storage := NewInMemoryOffsetStorage()
partition := createTestPartition()
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
if err != nil {
t.Fatalf("Failed to create offset manager: %v", err)
}
// Test sequential offset assignment
for i := int64(0); i < 10; i++ {
offset := manager.AssignOffset()
if offset != i {
t.Errorf("Expected offset %d, got %d", i, offset)
}
}
// Test high water mark (one past the last assigned offset)
hwm := manager.GetHighWaterMark()
if hwm != 10 {
t.Errorf("Expected high water mark 10, got %d", hwm)
}
}
// TestPartitionOffsetManager_BatchAssignment verifies that batch assignment
// returns a contiguous inclusive [base, last] range and that consecutive
// batches do not overlap.
func TestPartitionOffsetManager_BatchAssignment(t *testing.T) {
storage := NewInMemoryOffsetStorage()
partition := createTestPartition()
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
if err != nil {
t.Fatalf("Failed to create offset manager: %v", err)
}
// Assign batch of 5 offsets (0..4 inclusive)
baseOffset, lastOffset := manager.AssignOffsets(5)
if baseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", baseOffset)
}
if lastOffset != 4 {
t.Errorf("Expected last offset 4, got %d", lastOffset)
}
// Assign another batch (5..7), continuing from the previous batch
baseOffset, lastOffset = manager.AssignOffsets(3)
if baseOffset != 5 {
t.Errorf("Expected base offset 5, got %d", baseOffset)
}
if lastOffset != 7 {
t.Errorf("Expected last offset 7, got %d", lastOffset)
}
// Check high water mark (one past the last assigned offset)
hwm := manager.GetHighWaterMark()
if hwm != 8 {
t.Errorf("Expected high water mark 8, got %d", hwm)
}
}
// TestPartitionOffsetManager_Recovery verifies that a freshly constructed
// manager resumes offset assignment after the highest previously stored
// record, simulating a process restart.
//
// NOTE(review): the fixed 100ms sleep assumes the background checkpoint has
// completed by then — a potential source of flakiness; confirm whether the
// manager exposes a synchronous flush/checkpoint hook.
func TestPartitionOffsetManager_Recovery(t *testing.T) {
storage := NewInMemoryOffsetStorage()
partition := createTestPartition()
// Create manager and assign some offsets
manager1, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
if err != nil {
t.Fatalf("Failed to create offset manager: %v", err)
}
// Assign offsets and simulate records
for i := 0; i < 150; i++ { // More than checkpoint interval
offset := manager1.AssignOffset()
storage.AddRecord("test-namespace", "test-topic", partition, offset)
}
// Wait for checkpoint to complete
time.Sleep(100 * time.Millisecond)
// Create new manager (simulates restart)
manager2, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
if err != nil {
t.Fatalf("Failed to create offset manager after recovery: %v", err)
}
// Next offset should continue from checkpoint + 1
// With checkpoint interval 100, checkpoint happens at offset 100
// So recovery should start from 101, but we assigned 150 offsets (0-149)
// The checkpoint should be at 100, so next offset should be 101
// But since we have records up to 149, it should recover from storage scan
nextOffset := manager2.AssignOffset()
if nextOffset != 150 {
t.Errorf("Expected next offset 150 after recovery, got %d", nextOffset)
}
}
// TestPartitionOffsetManager_RecoveryFromStorage verifies recovery when no
// checkpoint exists: the manager must scan storage for the highest record
// and continue assignment from the next offset.
func TestPartitionOffsetManager_RecoveryFromStorage(t *testing.T) {
storage := NewInMemoryOffsetStorage()
partition := createTestPartition()
// Simulate existing records in storage without checkpoint
for i := int64(0); i < 50; i++ {
storage.AddRecord("test-namespace", "test-topic", partition, i)
}
// Create manager - should recover from storage scan
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
if err != nil {
t.Fatalf("Failed to create offset manager: %v", err)
}
// Next offset should be 50 (one past the highest stored record, 49)
nextOffset := manager.AssignOffset()
if nextOffset != 50 {
t.Errorf("Expected next offset 50 after storage recovery, got %d", nextOffset)
}
}
// TestPartitionOffsetRegistry_MultiplePartitions verifies that partitions
// with different ranges track fully independent offset sequences within the
// same registry.
func TestPartitionOffsetRegistry_MultiplePartitions(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
// Create different partitions (distinct ring ranges on the same ring)
partition1 := &schema_pb.Partition{
RingSize: 1024,
RangeStart: 0,
RangeStop: 31,
UnixTimeNs: time.Now().UnixNano(),
}
partition2 := &schema_pb.Partition{
RingSize: 1024,
RangeStart: 32,
RangeStop: 63,
UnixTimeNs: time.Now().UnixNano(),
}
// Assign offsets to different partitions; each starts at 0
offset1, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
if err != nil {
t.Fatalf("Failed to assign offset to partition1: %v", err)
}
if offset1 != 0 {
t.Errorf("Expected offset 0 for partition1, got %d", offset1)
}
offset2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
if err != nil {
t.Fatalf("Failed to assign offset to partition2: %v", err)
}
if offset2 != 0 {
t.Errorf("Expected offset 0 for partition2, got %d", offset2)
}
// Assign more offsets to partition1
offset1_2, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
if err != nil {
t.Fatalf("Failed to assign second offset to partition1: %v", err)
}
if offset1_2 != 1 {
t.Errorf("Expected offset 1 for partition1, got %d", offset1_2)
}
// Partition2's sequence advances independently of partition1's
offset2_2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
if err != nil {
t.Fatalf("Failed to assign second offset to partition2: %v", err)
}
if offset2_2 != 1 {
t.Errorf("Expected offset 1 for partition2, got %d", offset2_2)
}
}
// TestPartitionOffsetRegistry_BatchAssignment verifies registry-level batch
// assignment: an inclusive [base, last] range plus the resulting high water
// mark.
func TestPartitionOffsetRegistry_BatchAssignment(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
partition := createTestPartition()
// Assign batch of offsets (0..9 inclusive)
baseOffset, lastOffset, err := registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
if err != nil {
t.Fatalf("Failed to assign batch offsets: %v", err)
}
if baseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", baseOffset)
}
if lastOffset != 9 {
t.Errorf("Expected last offset 9, got %d", lastOffset)
}
// Get high water mark (one past the last assigned offset)
hwm, err := registry.GetHighWaterMark("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get high water mark: %v", err)
}
if hwm != 10 {
t.Errorf("Expected high water mark 10, got %d", hwm)
}
}
// TestOffsetAssigner_SingleAssignment verifies that a single assignment
// result carries the first offset (0), the originating partition, and a
// positive timestamp.
func TestOffsetAssigner_SingleAssignment(t *testing.T) {
storage := NewInMemoryOffsetStorage()
assigner := NewOffsetAssigner(storage)
partition := createTestPartition()
// Assign single offset
result := assigner.AssignSingleOffset("test-namespace", "test-topic", partition)
if result.Error != nil {
t.Fatalf("Failed to assign single offset: %v", result.Error)
}
if result.Assignment == nil {
t.Fatal("Assignment result is nil")
}
if result.Assignment.Offset != 0 {
t.Errorf("Expected offset 0, got %d", result.Assignment.Offset)
}
// Pointer identity check: the assignment must reference the exact
// partition object that was passed in.
if result.Assignment.Partition != partition {
t.Error("Partition mismatch in assignment")
}
if result.Assignment.Timestamp <= 0 {
t.Error("Timestamp should be set")
}
}
// TestOffsetAssigner_BatchAssignment verifies that a batch assignment result
// carries the inclusive base/last offsets, the requested count, and a
// positive timestamp.
func TestOffsetAssigner_BatchAssignment(t *testing.T) {
storage := NewInMemoryOffsetStorage()
assigner := NewOffsetAssigner(storage)
partition := createTestPartition()
// Assign batch of offsets (0..4 inclusive)
result := assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 5)
if result.Error != nil {
t.Fatalf("Failed to assign batch offsets: %v", result.Error)
}
if result.Batch == nil {
t.Fatal("Batch result is nil")
}
if result.Batch.BaseOffset != 0 {
t.Errorf("Expected base offset 0, got %d", result.Batch.BaseOffset)
}
if result.Batch.LastOffset != 4 {
t.Errorf("Expected last offset 4, got %d", result.Batch.LastOffset)
}
if result.Batch.Count != 5 {
t.Errorf("Expected count 5, got %d", result.Batch.Count)
}
if result.Batch.Timestamp <= 0 {
t.Error("Timestamp should be set")
}
}
// TestOffsetAssigner_HighWaterMark verifies that the high water mark starts
// at 0 and advances to one past the last assigned offset after a batch
// assignment.
func TestOffsetAssigner_HighWaterMark(t *testing.T) {
storage := NewInMemoryOffsetStorage()
assigner := NewOffsetAssigner(storage)
partition := createTestPartition()
// Initially should be 0
hwm, err := assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get initial high water mark: %v", err)
}
if hwm != 0 {
t.Errorf("Expected initial high water mark 0, got %d", hwm)
}
// Assign some offsets (result ignored; the hwm assertion below would
// expose an assignment failure)
assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 10)
// High water mark should be updated
hwm, err = assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get high water mark after assignment: %v", err)
}
if hwm != 10 {
t.Errorf("Expected high water mark 10, got %d", hwm)
}
}
// TestPartitionKey verifies that partitionKey is a pure function of the
// partition's identifying fields: identical partitions map to the same key,
// partitions with different ranges map to different keys.
func TestPartitionKey(t *testing.T) {
// partition1 and partition2 are field-for-field identical
partition1 := &schema_pb.Partition{
RingSize: 1024,
RangeStart: 0,
RangeStop: 31,
UnixTimeNs: 1234567890,
}
partition2 := &schema_pb.Partition{
RingSize: 1024,
RangeStart: 0,
RangeStop: 31,
UnixTimeNs: 1234567890,
}
// partition3 differs only in its range
partition3 := &schema_pb.Partition{
RingSize: 1024,
RangeStart: 32,
RangeStop: 63,
UnixTimeNs: 1234567890,
}
key1 := partitionKey(partition1)
key2 := partitionKey(partition2)
key3 := partitionKey(partition3)
// Same partitions should have same key
if key1 != key2 {
t.Errorf("Same partitions should have same key: %s vs %s", key1, key2)
}
// Different partitions should have different keys
if key1 == key3 {
t.Errorf("Different partitions should have different keys: %s vs %s", key1, key3)
}
}
// TestConcurrentOffsetAssignment verifies that concurrent offset assignment
// across multiple goroutines never produces duplicate or out-of-range
// offsets.
//
// Fix: the original deadlocked on the error path — a goroutine that hit an
// assignment error returned early without sending its remaining offsets,
// while the collector blocked forever waiting for exactly
// numGoroutines*offsetsPerGoroutine values. A WaitGroup plus closing the
// channel lets the collection loop terminate regardless of failures.
func TestConcurrentOffsetAssignment(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	registry := NewPartitionOffsetRegistry(storage)
	partition := createTestPartition()
	const numGoroutines = 10
	const offsetsPerGoroutine = 100
	// Buffered to full capacity so producers never block on send.
	results := make(chan int64, numGoroutines*offsetsPerGoroutine)
	var wg sync.WaitGroup
	// Start concurrent offset assignments
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < offsetsPerGoroutine; j++ {
				offset, err := registry.AssignOffset("test-namespace", "test-topic", partition)
				if err != nil {
					// t.Errorf is safe from goroutines (unlike t.Fatalf).
					t.Errorf("Failed to assign offset: %v", err)
					return
				}
				results <- offset
			}
		}()
	}
	// Close the channel once all producers finish so the range below ends
	// even if some goroutines bailed out early.
	wg.Wait()
	close(results)
	// Collect all results and check uniqueness
	offsets := make(map[int64]bool)
	for offset := range results {
		if offsets[offset] {
			t.Errorf("Duplicate offset assigned: %d", offset)
		}
		offsets[offset] = true
	}
	// Verify we got all expected offsets
	expectedCount := numGoroutines * offsetsPerGoroutine
	if len(offsets) != expectedCount {
		t.Errorf("Expected %d unique offsets, got %d", expectedCount, len(offsets))
	}
	// Verify offsets are in expected range
	for offset := range offsets {
		if offset < 0 || offset >= int64(expectedCount) {
			t.Errorf("Offset %d is out of expected range [0, %d)", offset, expectedCount)
		}
	}
}

View File

@@ -1,302 +0,0 @@
package offset
import (
"database/sql"
"fmt"
"time"
)
// MigrationVersion represents a single database schema migration: a
// monotonically increasing version number plus the SQL that brings the
// schema up to that version.
type MigrationVersion struct {
Version int // Target schema version; must be unique and ascending across GetMigrations
Description string // Human-readable summary, recorded in schema_migrations on apply
SQL string // DDL executed inside one transaction by ApplyMigrations
}
// GetMigrations returns all available migrations for offset storage, ordered
// by ascending Version. ApplyMigrations and ValidateSchema rely on this
// ordering (the last element is treated as the latest version).
func GetMigrations() []MigrationVersion {
return []MigrationVersion{
{
Version: 1,
Description: "Create initial offset storage tables",
SQL: `
-- Partition offset checkpoints table
-- TODO: Add _index as computed column when supported by database
CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
partition_key TEXT PRIMARY KEY,
ring_size INTEGER NOT NULL,
range_start INTEGER NOT NULL,
range_stop INTEGER NOT NULL,
unix_time_ns INTEGER NOT NULL,
checkpoint_offset INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
-- Offset mappings table for detailed tracking
-- TODO: Add _index as computed column when supported by database
CREATE TABLE IF NOT EXISTS offset_mappings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
partition_key TEXT NOT NULL,
kafka_offset INTEGER NOT NULL,
smq_timestamp INTEGER NOT NULL,
message_size INTEGER NOT NULL,
created_at INTEGER NOT NULL,
UNIQUE(partition_key, kafka_offset)
);
-- Schema migrations tracking table
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at INTEGER NOT NULL
);
`,
},
{
Version: 2,
Description: "Add indexes for performance optimization",
SQL: `
-- Indexes for performance
CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
ON partition_offset_checkpoints(partition_key);
CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
ON offset_mappings(partition_key, kafka_offset);
CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
ON offset_mappings(partition_key, smq_timestamp);
CREATE INDEX IF NOT EXISTS idx_offset_mappings_created_at
ON offset_mappings(created_at);
`,
},
{
Version: 3,
Description: "Add partition metadata table for enhanced tracking",
SQL: `
-- Partition metadata table
CREATE TABLE IF NOT EXISTS partition_metadata (
partition_key TEXT PRIMARY KEY,
ring_size INTEGER NOT NULL,
range_start INTEGER NOT NULL,
range_stop INTEGER NOT NULL,
unix_time_ns INTEGER NOT NULL,
created_at INTEGER NOT NULL,
last_activity_at INTEGER NOT NULL,
record_count INTEGER DEFAULT 0,
total_size INTEGER DEFAULT 0
);
-- Index for partition metadata
CREATE INDEX IF NOT EXISTS idx_partition_metadata_activity
ON partition_metadata(last_activity_at);
`,
},
}
}
// MigrationManager handles database schema migrations for offset storage,
// tracking applied versions in the schema_migrations table.
type MigrationManager struct {
db *sql.DB // Database handle migrations are applied against
}
// NewMigrationManager creates a migration manager operating on the given
// database handle.
func NewMigrationManager(db *sql.DB) *MigrationManager {
	m := &MigrationManager{db: db}
	return m
}
// GetCurrentVersion returns the highest migration version recorded in
// schema_migrations, bootstrapping that table if it does not exist yet.
// A return of 0 with a nil error means no migrations have been applied.
func (m *MigrationManager) GetCurrentVersion() (int, error) {
// First, ensure the migrations table exists
_, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at INTEGER NOT NULL
)
`)
if err != nil {
return 0, fmt.Errorf("failed to create migrations table: %w", err)
}
// MAX(version) over an empty table yields SQL NULL, hence NullInt64.
var version sql.NullInt64
err = m.db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
if err != nil {
return 0, fmt.Errorf("failed to get current version: %w", err)
}
if !version.Valid {
return 0, nil // No migrations applied yet
}
return int(version.Int64), nil
}
// ApplyMigrations applies all migrations newer than the current schema
// version, in ascending order. Each migration runs in its own transaction;
// the first failure aborts the run, leaving earlier migrations applied.
func (m *MigrationManager) ApplyMigrations() error {
currentVersion, err := m.GetCurrentVersion()
if err != nil {
return fmt.Errorf("failed to get current version: %w", err)
}
migrations := GetMigrations()
for _, migration := range migrations {
if migration.Version <= currentVersion {
continue // Already applied
}
fmt.Printf("Applying migration %d: %s\n", migration.Version, migration.Description)
// Begin transaction
tx, err := m.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
}
// Execute migration SQL
_, err = tx.Exec(migration.SQL)
if err != nil {
// Rollback error deliberately ignored: the original failure is what
// gets reported to the caller.
tx.Rollback()
return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
}
// Record migration as applied (same transaction as the DDL)
_, err = tx.Exec(
"INSERT INTO schema_migrations (version, description, applied_at) VALUES (?, ?, ?)",
migration.Version,
migration.Description,
getCurrentTimestamp(),
)
if err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
}
// Commit transaction
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
}
fmt.Printf("Successfully applied migration %d\n", migration.Version)
}
return nil
}
// RollbackMigration rolls back a specific migration (if supported).
// Currently a stub: it always returns an error and the version parameter is
// not yet used.
func (m *MigrationManager) RollbackMigration(version int) error {
// TODO: Implement rollback functionality
// ASSUMPTION: For now, rollbacks are not supported as they require careful planning
return fmt.Errorf("migration rollbacks not implemented - manual intervention required")
}
// GetAppliedMigrations returns all migrations recorded in schema_migrations,
// ordered by ascending version.
//
// Fix: the original never checked rows.Err(), so an iteration failure (e.g.
// a connection dropping mid-scan) silently returned a truncated list.
func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) {
	rows, err := m.db.Query(`
SELECT version, description, applied_at
FROM schema_migrations
ORDER BY version
`)
	if err != nil {
		return nil, fmt.Errorf("failed to query applied migrations: %w", err)
	}
	defer rows.Close()
	var migrations []AppliedMigration
	for rows.Next() {
		var migration AppliedMigration
		if err := rows.Scan(&migration.Version, &migration.Description, &migration.AppliedAt); err != nil {
			return nil, fmt.Errorf("failed to scan migration: %w", err)
		}
		migrations = append(migrations, migration)
	}
	// rows.Next returning false can mean failure rather than end-of-set;
	// surface any deferred iteration error instead of dropping it.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate applied migrations: %w", err)
	}
	return migrations, nil
}
// ValidateSchema reports an error when the recorded schema version is behind
// the latest known migration; a nil return means the schema is up to date
// (or there are no migrations at all).
func (m *MigrationManager) ValidateSchema() error {
	current, err := m.GetCurrentVersion()
	if err != nil {
		return fmt.Errorf("failed to get current version: %w", err)
	}
	all := GetMigrations()
	if len(all) == 0 {
		return nil
	}
	// Migrations are ordered ascending, so the last one is the latest.
	if latest := all[len(all)-1].Version; current < latest {
		return fmt.Errorf("schema is outdated: current version %d, latest version %d", current, latest)
	}
	return nil
}
// AppliedMigration represents one row of the schema_migrations table: a
// migration that has already been applied to the database.
type AppliedMigration struct {
Version int // Applied schema version
Description string // Description recorded at apply time
AppliedAt int64 // Apply timestamp in nanoseconds (see getCurrentTimestamp)
}
// getCurrentTimestamp returns the current wall-clock time as nanoseconds
// since the Unix epoch.
func getCurrentTimestamp() int64 {
	now := time.Now()
	return now.UnixNano()
}
// CreateDatabase opens (creating if necessary) an offset storage database at
// dbPath, tunes SQLite pragmas, and applies all pending migrations. On any
// failure the handle is closed before the error is returned.
//
// NOTE(review): sql.Open("sqlite3", ...) assumes a sqlite3 driver has been
// registered by the importing binary — confirm the driver import exists.
func CreateDatabase(dbPath string) (*sql.DB, error) {
// TODO: Support different database types (PostgreSQL, MySQL, etc.)
// ASSUMPTION: Using SQLite for now, can be extended for other databases
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Configure SQLite for better performance
pragmas := []string{
"PRAGMA journal_mode=WAL", // Write-Ahead Logging for better concurrency
"PRAGMA synchronous=NORMAL", // Balance between safety and performance
"PRAGMA cache_size=10000", // Increase cache size
"PRAGMA foreign_keys=ON", // Enable foreign key constraints
"PRAGMA temp_store=MEMORY", // Store temporary tables in memory
}
for _, pragma := range pragmas {
_, err := db.Exec(pragma)
if err != nil {
db.Close()
return nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err)
}
}
// Apply migrations
migrationManager := NewMigrationManager(db)
err = migrationManager.ApplyMigrations()
if err != nil {
db.Close()
return nil, fmt.Errorf("failed to apply migrations: %w", err)
}
return db, nil
}
// BackupDatabase creates a backup of the offset storage database.
// Currently a stub: it always returns an error and both parameters are
// unused.
func BackupDatabase(sourceDB *sql.DB, backupPath string) error {
// TODO: Implement database backup functionality
// ASSUMPTION: This would use database-specific backup mechanisms
return fmt.Errorf("database backup not implemented yet")
}
// RestoreDatabase restores a database from a backup.
// Currently a stub: it always returns an error and both parameters are
// unused.
func RestoreDatabase(backupPath, targetPath string) error {
// TODO: Implement database restore functionality
// ASSUMPTION: This would use database-specific restore mechanisms
return fmt.Errorf("database restore not implemented yet")
}

View File

@@ -1,394 +0,0 @@
package offset
import (
"database/sql"
"fmt"
"time"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// OffsetEntry represents a mapping between a Kafka-style sequential offset
// and the corresponding SMQ timestamp for one stored message.
type OffsetEntry struct {
KafkaOffset int64 // Sequential per-partition offset
SMQTimestamp int64 // SMQ message timestamp — presumably nanoseconds; TODO confirm unit
MessageSize int32 // Message size in bytes
}
// SQLOffsetStorage implements OffsetStorage using a SQL database. The
// "_index" hidden-column concept mentioned in initializeSchema is planned
// but not yet implemented; regular columns are used instead.
type SQLOffsetStorage struct {
db *sql.DB // Database handle; schema is created lazily by NewSQLOffsetStorage
}
// NewSQLOffsetStorage creates a SQL-based offset storage backed by db,
// creating the required tables and indexes if they do not yet exist.
func NewSQLOffsetStorage(db *sql.DB) (*SQLOffsetStorage, error) {
	s := &SQLOffsetStorage{db: db}
	if err := s.initializeSchema(); err != nil {
		return nil, fmt.Errorf("failed to initialize schema: %w", err)
	}
	return s, nil
}
// initializeSchema creates the checkpoint and offset-mapping tables plus
// their indexes. All statements use IF NOT EXISTS, so it is safe to call on
// an already-initialized database.
func (s *SQLOffsetStorage) initializeSchema() error {
// TODO: Create offset storage tables with _index as hidden column
// ASSUMPTION: Using SQLite-compatible syntax, may need adaptation for other databases
queries := []string{
// Partition offset checkpoints table
// TODO: Add _index as computed column when supported by database
// ASSUMPTION: Using regular columns for now, _index concept preserved for future enhancement
`CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
partition_key TEXT PRIMARY KEY,
ring_size INTEGER NOT NULL,
range_start INTEGER NOT NULL,
range_stop INTEGER NOT NULL,
unix_time_ns INTEGER NOT NULL,
checkpoint_offset INTEGER NOT NULL,
updated_at INTEGER NOT NULL
)`,
// Offset mappings table for detailed tracking
// TODO: Add _index as computed column when supported by database
`CREATE TABLE IF NOT EXISTS offset_mappings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
partition_key TEXT NOT NULL,
kafka_offset INTEGER NOT NULL,
smq_timestamp INTEGER NOT NULL,
message_size INTEGER NOT NULL,
created_at INTEGER NOT NULL,
UNIQUE(partition_key, kafka_offset)
)`,
// Indexes for performance
`CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
ON partition_offset_checkpoints(partition_key)`,
`CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
ON offset_mappings(partition_key, kafka_offset)`,
`CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
ON offset_mappings(partition_key, smq_timestamp)`,
}
// Execute sequentially; the first failure aborts initialization.
for _, query := range queries {
if _, err := s.db.Exec(query); err != nil {
return fmt.Errorf("failed to execute schema query: %w", err)
}
}
return nil
}
// SaveCheckpoint persists (inserting or overwriting) the checkpoint offset
// for a partition, keyed by topic-scoped partition key so different topics
// never collide.
func (s *SQLOffsetStorage) SaveCheckpoint(namespace, topicName string, partition *schema_pb.Partition, offset int64) error {
// Use TopicPartitionKey to ensure each topic has isolated checkpoint storage
partitionKey := TopicPartitionKey(namespace, topicName, partition)
now := time.Now().UnixNano()
// TODO: Use UPSERT for better performance
// ASSUMPTION: SQLite REPLACE syntax, may need adaptation for other databases
query := `
REPLACE INTO partition_offset_checkpoints
(partition_key, ring_size, range_start, range_stop, unix_time_ns, checkpoint_offset, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
_, err := s.db.Exec(query,
partitionKey,
partition.RingSize,
partition.RangeStart,
partition.RangeStop,
partition.UnixTimeNs,
offset,
now,
)
if err != nil {
return fmt.Errorf("failed to save checkpoint: %w", err)
}
return nil
}
// LoadCheckpoint returns the last checkpointed offset for a topic partition.
// It returns -1 plus an error when no checkpoint row exists, or when the
// query itself fails. The key derivation must match SaveCheckpoint.
func (s *SQLOffsetStorage) LoadCheckpoint(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
	key := TopicPartitionKey(namespace, topicName, partition)
	const stmt = `
SELECT checkpoint_offset
FROM partition_offset_checkpoints
WHERE partition_key = ?
`
	var offset int64
	switch err := s.db.QueryRow(stmt, key).Scan(&offset); {
	case err == sql.ErrNoRows:
		// No row yet: the partition has never been checkpointed.
		return -1, fmt.Errorf("no checkpoint found")
	case err != nil:
		return -1, fmt.Errorf("failed to load checkpoint: %w", err)
	}
	return offset, nil
}
// GetHighestOffset finds the highest offset recorded for a partition in the
// offset_mappings table, returning -1 plus an error when the partition has
// no mappings at all. The key derivation must match SaveCheckpoint.
func (s *SQLOffsetStorage) GetHighestOffset(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
	key := TopicPartitionKey(namespace, topicName, partition)
	// TODO: Use _index column for efficient querying
	// ASSUMPTION: kafka_offset represents the sequential offset we're tracking
	const stmt = `
SELECT MAX(kafka_offset)
FROM offset_mappings
WHERE partition_key = ?
`
	// MAX() yields NULL on an empty set, so scan into a nullable first.
	var highest sql.NullInt64
	if err := s.db.QueryRow(stmt, key).Scan(&highest); err != nil {
		return -1, fmt.Errorf("failed to get highest offset: %w", err)
	}
	if !highest.Valid {
		return -1, fmt.Errorf("no records found")
	}
	return highest.Int64, nil
}
// SaveOffsetMapping stores an offset mapping (extends OffsetStorage interface).
// A pre-existing row for the same (partition_key, kafka_offset) pair is
// replaced rather than rejected, per the table's UNIQUE constraint.
func (s *SQLOffsetStorage) SaveOffsetMapping(partitionKey string, kafkaOffset, smqTimestamp int64, size int32) error {
	createdAt := time.Now().UnixNano()
	// INSERT OR REPLACE resolves duplicate-key conflicts in place.
	const stmt = `
INSERT OR REPLACE INTO offset_mappings
(partition_key, kafka_offset, smq_timestamp, message_size, created_at)
VALUES (?, ?, ?, ?, ?)
`
	if _, err := s.db.Exec(stmt, partitionKey, kafkaOffset, smqTimestamp, size, createdAt); err != nil {
		return fmt.Errorf("failed to save offset mapping: %w", err)
	}
	return nil
}
// LoadOffsetMappings retrieves every offset mapping for a partition, sorted
// by ascending Kafka offset. Results are unpaginated for now (TODO: add
// pagination before using this on very large partitions).
func (s *SQLOffsetStorage) LoadOffsetMappings(partitionKey string) ([]OffsetEntry, error) {
	const stmt = `
SELECT kafka_offset, smq_timestamp, message_size
FROM offset_mappings
WHERE partition_key = ?
ORDER BY kafka_offset ASC
`
	rows, err := s.db.Query(stmt, partitionKey)
	if err != nil {
		return nil, fmt.Errorf("failed to query offset mappings: %w", err)
	}
	defer rows.Close()
	var result []OffsetEntry
	for rows.Next() {
		var e OffsetEntry
		if scanErr := rows.Scan(&e.KafkaOffset, &e.SMQTimestamp, &e.MessageSize); scanErr != nil {
			return nil, fmt.Errorf("failed to scan offset entry: %w", scanErr)
		}
		result = append(result, e)
	}
	// Surface any error that terminated iteration early.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating offset mappings: %w", err)
	}
	return result, nil
}
// GetOffsetMappingsByRange retrieves offset mappings within the inclusive
// [startOffset, endOffset] range for a partition, ordered by ascending
// Kafka offset. Returns a nil slice when nothing falls in the range.
func (s *SQLOffsetStorage) GetOffsetMappingsByRange(partitionKey string, startOffset, endOffset int64) ([]OffsetEntry, error) {
	// TODO: Use _index column for efficient range queries
	query := `
SELECT kafka_offset, smq_timestamp, message_size
FROM offset_mappings
WHERE partition_key = ? AND kafka_offset >= ? AND kafka_offset <= ?
ORDER BY kafka_offset ASC
`
	rows, err := s.db.Query(query, partitionKey, startOffset, endOffset)
	if err != nil {
		return nil, fmt.Errorf("failed to query offset range: %w", err)
	}
	defer rows.Close()
	var entries []OffsetEntry
	for rows.Next() {
		var entry OffsetEntry
		if err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize); err != nil {
			return nil, fmt.Errorf("failed to scan offset entry: %w", err)
		}
		entries = append(entries, entry)
	}
	// Fix: previously iteration errors were silently dropped; check rows.Err()
	// (as LoadOffsetMappings does) so a partial result set is not returned as
	// success when the cursor failed mid-iteration.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating offset mappings: %w", err)
	}
	return entries, nil
}
// GetPartitionStats aggregates offset usage for one partition: record count,
// earliest/latest offsets, high water mark (latest+1), total payload size,
// and first/last record creation times. When the partition is empty the
// offsets are reported as -1 and the high water mark as 0.
func (s *SQLOffsetStorage) GetPartitionStats(partitionKey string) (*PartitionStats, error) {
	const stmt = `
SELECT
COUNT(*) as record_count,
MIN(kafka_offset) as earliest_offset,
MAX(kafka_offset) as latest_offset,
SUM(message_size) as total_size,
MIN(created_at) as first_record_time,
MAX(created_at) as last_record_time
FROM offset_mappings
WHERE partition_key = ?
`
	// The aggregate columns come back NULL for an empty partition, so scan
	// each of them through a nullable before mapping to sentinel values.
	var (
		earliest, latest sql.NullInt64
		size             sql.NullInt64
		first, last      sql.NullInt64
	)
	stats := &PartitionStats{PartitionKey: partitionKey}
	if err := s.db.QueryRow(stmt, partitionKey).Scan(
		&stats.RecordCount,
		&earliest,
		&latest,
		&size,
		&first,
		&last,
	); err != nil {
		return nil, fmt.Errorf("failed to get partition stats: %w", err)
	}
	stats.EarliestOffset = -1
	if earliest.Valid {
		stats.EarliestOffset = earliest.Int64
	}
	stats.LatestOffset = -1
	if latest.Valid {
		stats.LatestOffset = latest.Int64
		stats.HighWaterMark = latest.Int64 + 1
	}
	if first.Valid {
		stats.FirstRecordTime = first.Int64
	}
	if last.Valid {
		stats.LastRecordTime = last.Int64
	}
	if size.Valid {
		stats.TotalSize = size.Int64
	}
	return stats, nil
}
// CleanupOldMappings deletes every offset mapping whose created_at timestamp
// is strictly older than olderThanNs (nanoseconds since epoch). A non-zero
// deletion count is printed for visibility. TODO: configurable retention
// policies beyond this simple time-based cleanup.
func (s *SQLOffsetStorage) CleanupOldMappings(olderThanNs int64) error {
	const stmt = `
DELETE FROM offset_mappings
WHERE created_at < ?
`
	res, err := s.db.Exec(stmt, olderThanNs)
	if err != nil {
		return fmt.Errorf("failed to cleanup old mappings: %w", err)
	}
	// Best effort: a RowsAffected error is deliberately ignored since the
	// delete itself already succeeded.
	if n, _ := res.RowsAffected(); n > 0 {
		fmt.Printf("Cleaned up %d old offset mappings\n", n)
	}
	return nil
}
// Close releases the underlying database connection, if one was set.
func (s *SQLOffsetStorage) Close() error {
	if s.db == nil {
		return nil
	}
	return s.db.Close()
}
// PartitionStats provides statistics about a partition's offset usage.
// It is populated by GetPartitionStats from aggregates over offset_mappings.
type PartitionStats struct {
PartitionKey string // key identifying the partition (see TopicPartitionKey)
RecordCount int64 // number of offset mapping rows for the partition
EarliestOffset int64 // lowest kafka_offset, or -1 when the partition is empty
LatestOffset int64 // highest kafka_offset, or -1 when the partition is empty
HighWaterMark int64 // LatestOffset + 1, or 0 when the partition is empty
TotalSize int64 // sum of message_size over all rows
FirstRecordTime int64 // earliest created_at in nanoseconds (0 when empty)
LastRecordTime int64 // latest created_at in nanoseconds (0 when empty)
}
// GetAllPartitions returns the distinct partition keys that currently have
// offset mapping data, sorted lexicographically. An empty database yields a
// nil slice and no error.
func (s *SQLOffsetStorage) GetAllPartitions() ([]string, error) {
	query := `
SELECT DISTINCT partition_key
FROM offset_mappings
ORDER BY partition_key
`
	rows, err := s.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("failed to get all partitions: %w", err)
	}
	defer rows.Close()
	var partitions []string
	for rows.Next() {
		var partitionKey string
		if err := rows.Scan(&partitionKey); err != nil {
			return nil, fmt.Errorf("failed to scan partition key: %w", err)
		}
		partitions = append(partitions, partitionKey)
	}
	// Fix: iteration errors were previously ignored; check rows.Err() so a
	// truncated key list is not silently returned as the complete set.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating partition keys: %w", err)
	}
	return partitions, nil
}
// Vacuum performs database maintenance operations. It issues SQLite's VACUUM
// command; other database engines would need their own maintenance statement.
func (s *SQLOffsetStorage) Vacuum() error {
	if _, err := s.db.Exec("VACUUM"); err != nil {
		return fmt.Errorf("failed to vacuum database: %w", err)
	}
	return nil
}

View File

@@ -1,516 +0,0 @@
package offset
import (
"database/sql"
"os"
"testing"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// createTestDB opens a throwaway SQLite database backed by a temp file.
// Both the file and the connection are torn down via t.Cleanup when the
// calling test finishes.
func createTestDB(t *testing.T) *sql.DB {
	tmp, err := os.CreateTemp("", "offset_test_*.db")
	if err != nil {
		t.Fatalf("Failed to create temp database file: %v", err)
	}
	tmp.Close()
	t.Cleanup(func() { os.Remove(tmp.Name()) })
	conn, err := sql.Open("sqlite3", tmp.Name())
	if err != nil {
		t.Fatalf("Failed to open database: %v", err)
	}
	t.Cleanup(func() { conn.Close() })
	return conn
}
// createTestPartitionForSQL builds a fixed test partition covering ring
// slots 0-31 of a 1024-slot ring, stamped with the current time.
func createTestPartitionForSQL() *schema_pb.Partition {
	p := &schema_pb.Partition{}
	p.RingSize = 1024
	p.RangeStart = 0
	p.RangeStop = 31
	p.UnixTimeNs = time.Now().UnixNano()
	return p
}
// TestSQLOffsetStorage_InitializeSchema verifies that constructing the SQL
// storage creates both backing tables, by querying sqlite_master directly.
func TestSQLOffsetStorage_InitializeSchema(t *testing.T) {
	db := createTestDB(t)
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	for _, tableName := range []string{"partition_offset_checkpoints", "offset_mappings"} {
		var n int
		if err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&n); err != nil {
			t.Fatalf("Failed to check table %s: %v", tableName, err)
		}
		if n != 1 {
			t.Errorf("Table %s was not created", tableName)
		}
	}
}
// TestSQLOffsetStorage_SaveLoadCheckpoint exercises the checkpoint
// round-trip: save, load, overwrite with a new value, and load again to
// confirm the REPLACE semantics of SaveCheckpoint.
func TestSQLOffsetStorage_SaveLoadCheckpoint(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
// Test saving checkpoint
err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 100)
if err != nil {
t.Fatalf("Failed to save checkpoint: %v", err)
}
// Test loading checkpoint
checkpoint, err := storage.LoadCheckpoint("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to load checkpoint: %v", err)
}
if checkpoint != 100 {
t.Errorf("Expected checkpoint 100, got %d", checkpoint)
}
// Test updating checkpoint: a second save for the same partition key must
// replace the earlier row, not add a new one.
err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 200)
if err != nil {
t.Fatalf("Failed to update checkpoint: %v", err)
}
checkpoint, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to load updated checkpoint: %v", err)
}
if checkpoint != 200 {
t.Errorf("Expected updated checkpoint 200, got %d", checkpoint)
}
}
// TestSQLOffsetStorage_LoadCheckpointNotFound verifies that loading a
// checkpoint that was never saved reports an error rather than a zero value.
func TestSQLOffsetStorage_LoadCheckpointNotFound(t *testing.T) {
	db := createTestDB(t)
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	if _, loadErr := storage.LoadCheckpoint("test-namespace", "test-topic", createTestPartitionForSQL()); loadErr == nil {
		t.Error("Expected error for non-existent checkpoint")
	}
}
// TestSQLOffsetStorage_SaveLoadOffsetMappings writes several offset
// mappings and checks that LoadOffsetMappings returns all of them, sorted
// by ascending offset, with timestamps and sizes intact.
func TestSQLOffsetStorage_SaveLoadOffsetMappings(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
// NOTE(review): the local variable shadows the package-level partitionKey
// function from this point on — intentional here, but easy to trip over.
partitionKey := partitionKey(partition)
// Save multiple offset mappings
mappings := []struct {
offset int64
timestamp int64
size int32
}{
{0, 1000, 100},
{1, 2000, 150},
{2, 3000, 200},
}
for _, mapping := range mappings {
err := storage.SaveOffsetMapping(partitionKey, mapping.offset, mapping.timestamp, mapping.size)
if err != nil {
t.Fatalf("Failed to save offset mapping: %v", err)
}
}
// Load offset mappings
entries, err := storage.LoadOffsetMappings(partitionKey)
if err != nil {
t.Fatalf("Failed to load offset mappings: %v", err)
}
if len(entries) != len(mappings) {
t.Errorf("Expected %d entries, got %d", len(mappings), len(entries))
}
// Verify entries are sorted by offset; inputs were saved in order, so a
// positional comparison suffices.
for i, entry := range entries {
expected := mappings[i]
if entry.KafkaOffset != expected.offset {
t.Errorf("Entry %d: expected offset %d, got %d", i, expected.offset, entry.KafkaOffset)
}
if entry.SMQTimestamp != expected.timestamp {
t.Errorf("Entry %d: expected timestamp %d, got %d", i, expected.timestamp, entry.SMQTimestamp)
}
if entry.MessageSize != expected.size {
t.Errorf("Entry %d: expected size %d, got %d", i, expected.size, entry.MessageSize)
}
}
}
// TestSQLOffsetStorage_GetHighestOffset checks the empty-partition error
// path, then saves out-of-order offsets and verifies MAX(kafka_offset)
// semantics return the true highest value.
func TestSQLOffsetStorage_GetHighestOffset(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
partitionKey := TopicPartitionKey("test-namespace", "test-topic", partition)
// Test empty partition
_, err = storage.GetHighestOffset("test-namespace", "test-topic", partition)
if err == nil {
t.Error("Expected error for empty partition")
}
// Add some offset mappings in deliberately shuffled order.
offsets := []int64{5, 1, 3, 2, 4}
for _, offset := range offsets {
err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
if err != nil {
t.Fatalf("Failed to save offset mapping: %v", err)
}
}
// Get highest offset
highest, err := storage.GetHighestOffset("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get highest offset: %v", err)
}
if highest != 5 {
t.Errorf("Expected highest offset 5, got %d", highest)
}
}
// TestSQLOffsetStorage_GetOffsetMappingsByRange saves offsets 0-9 and
// verifies that a [3, 7] range query returns exactly the five inclusive
// offsets, in ascending order.
func TestSQLOffsetStorage_GetOffsetMappingsByRange(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
partitionKey := partitionKey(partition)
// Add offset mappings
for i := int64(0); i < 10; i++ {
err := storage.SaveOffsetMapping(partitionKey, i, i*1000, 100)
if err != nil {
t.Fatalf("Failed to save offset mapping: %v", err)
}
}
// Get range of offsets; both range endpoints are inclusive.
entries, err := storage.GetOffsetMappingsByRange(partitionKey, 3, 7)
if err != nil {
t.Fatalf("Failed to get offset range: %v", err)
}
expectedCount := 5 // offsets 3, 4, 5, 6, 7
if len(entries) != expectedCount {
t.Errorf("Expected %d entries, got %d", expectedCount, len(entries))
}
// Verify range
for i, entry := range entries {
expectedOffset := int64(3 + i)
if entry.KafkaOffset != expectedOffset {
t.Errorf("Entry %d: expected offset %d, got %d", i, expectedOffset, entry.KafkaOffset)
}
}
}
// TestSQLOffsetStorage_GetPartitionStats verifies the aggregate statistics:
// sentinel values (-1 offsets) for an empty partition, then counts, offset
// bounds, high water mark, and summed size after inserting three records.
func TestSQLOffsetStorage_GetPartitionStats(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
partitionKey := partitionKey(partition)
// Test empty partition stats: the aggregate query still succeeds, with
// NULL aggregates mapped to sentinel values.
stats, err := storage.GetPartitionStats(partitionKey)
if err != nil {
t.Fatalf("Failed to get empty partition stats: %v", err)
}
if stats.RecordCount != 0 {
t.Errorf("Expected record count 0, got %d", stats.RecordCount)
}
if stats.EarliestOffset != -1 {
t.Errorf("Expected earliest offset -1, got %d", stats.EarliestOffset)
}
// Add some data
sizes := []int32{100, 150, 200}
for i, size := range sizes {
err := storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), size)
if err != nil {
t.Fatalf("Failed to save offset mapping: %v", err)
}
}
// Get stats with data
stats, err = storage.GetPartitionStats(partitionKey)
if err != nil {
t.Fatalf("Failed to get partition stats: %v", err)
}
if stats.RecordCount != 3 {
t.Errorf("Expected record count 3, got %d", stats.RecordCount)
}
if stats.EarliestOffset != 0 {
t.Errorf("Expected earliest offset 0, got %d", stats.EarliestOffset)
}
if stats.LatestOffset != 2 {
t.Errorf("Expected latest offset 2, got %d", stats.LatestOffset)
}
// High water mark is defined as latest offset + 1.
if stats.HighWaterMark != 3 {
t.Errorf("Expected high water mark 3, got %d", stats.HighWaterMark)
}
expectedTotalSize := int64(100 + 150 + 200)
if stats.TotalSize != expectedTotalSize {
t.Errorf("Expected total size %d, got %d", expectedTotalSize, stats.TotalSize)
}
}
// TestSQLOffsetStorage_GetAllPartitions verifies that GetAllPartitions
// returns an empty set for a fresh database, and exactly the distinct
// partition keys once mappings for two partitions exist.
func TestSQLOffsetStorage_GetAllPartitions(t *testing.T) {
	db := createTestDB(t)
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	// Test empty database
	partitions, err := storage.GetAllPartitions()
	if err != nil {
		t.Fatalf("Failed to get all partitions: %v", err)
	}
	if len(partitions) != 0 {
		t.Errorf("Expected 0 partitions, got %d", len(partitions))
	}
	// Add data for multiple partitions
	partition1 := createTestPartitionForSQL()
	partition2 := &schema_pb.Partition{
		RingSize:   1024,
		RangeStart: 32,
		RangeStop:  63,
		UnixTimeNs: time.Now().UnixNano(),
	}
	partitionKey1 := partitionKey(partition1)
	partitionKey2 := partitionKey(partition2)
	// Fix: the save errors were previously ignored; a silent setup failure
	// would make the assertions below fail for a misleading reason.
	if err := storage.SaveOffsetMapping(partitionKey1, 0, 1000, 100); err != nil {
		t.Fatalf("Failed to save offset mapping: %v", err)
	}
	if err := storage.SaveOffsetMapping(partitionKey2, 0, 2000, 150); err != nil {
		t.Fatalf("Failed to save offset mapping: %v", err)
	}
	// Get all partitions
	partitions, err = storage.GetAllPartitions()
	if err != nil {
		t.Fatalf("Failed to get all partitions: %v", err)
	}
	if len(partitions) != 2 {
		t.Errorf("Expected 2 partitions, got %d", len(partitions))
	}
	// Verify partition keys are present (order-independent membership check)
	partitionMap := make(map[string]bool)
	for _, p := range partitions {
		partitionMap[p] = true
	}
	if !partitionMap[partitionKey1] {
		t.Errorf("Partition key %s not found", partitionKey1)
	}
	if !partitionMap[partitionKey2] {
		t.Errorf("Partition key %s not found", partitionKey2)
	}
}
// TestSQLOffsetStorage_CleanupOldMappings inserts one artificially old
// mapping (via direct SQL, since SaveOffsetMapping always stamps "now") and
// one fresh mapping, then verifies a 12-hour cutoff removes only the old row.
func TestSQLOffsetStorage_CleanupOldMappings(t *testing.T) {
	db := createTestDB(t)
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	partition := createTestPartitionForSQL()
	partitionKey := partitionKey(partition)
	// Add mappings with different timestamps
	now := time.Now().UnixNano()
	// Add old mapping by directly inserting with old timestamp
	oldTime := now - (24 * time.Hour).Nanoseconds() // 24 hours ago
	_, err = db.Exec(`
INSERT INTO offset_mappings
(partition_key, kafka_offset, smq_timestamp, message_size, created_at)
VALUES (?, ?, ?, ?, ?)
`, partitionKey, 0, oldTime, 100, oldTime)
	if err != nil {
		t.Fatalf("Failed to insert old mapping: %v", err)
	}
	// Add recent mapping.
	// Fix: the save error was previously ignored; check it so a setup
	// failure surfaces here rather than as a confusing count mismatch below.
	if err := storage.SaveOffsetMapping(partitionKey, 1, now, 150); err != nil {
		t.Fatalf("Failed to save offset mapping: %v", err)
	}
	// Verify both mappings exist
	entries, err := storage.LoadOffsetMappings(partitionKey)
	if err != nil {
		t.Fatalf("Failed to load mappings: %v", err)
	}
	if len(entries) != 2 {
		t.Errorf("Expected 2 mappings before cleanup, got %d", len(entries))
	}
	// Cleanup old mappings (older than 12 hours)
	cutoffTime := now - (12 * time.Hour).Nanoseconds()
	err = storage.CleanupOldMappings(cutoffTime)
	if err != nil {
		t.Fatalf("Failed to cleanup old mappings: %v", err)
	}
	// Verify only recent mapping remains
	entries, err = storage.LoadOffsetMappings(partitionKey)
	if err != nil {
		t.Fatalf("Failed to load mappings after cleanup: %v", err)
	}
	if len(entries) != 1 {
		t.Errorf("Expected 1 mapping after cleanup, got %d", len(entries))
	}
	if entries[0].KafkaOffset != 1 {
		t.Errorf("Expected remaining mapping offset 1, got %d", entries[0].KafkaOffset)
	}
}
// TestSQLOffsetStorage_Vacuum verifies that Vacuum succeeds on both an
// empty database and one containing offset data.
func TestSQLOffsetStorage_Vacuum(t *testing.T) {
	db := createTestDB(t)
	storage, err := NewSQLOffsetStorage(db)
	if err != nil {
		t.Fatalf("Failed to create SQL storage: %v", err)
	}
	defer storage.Close()
	// Vacuum should not fail on empty database
	err = storage.Vacuum()
	if err != nil {
		t.Fatalf("Failed to vacuum database: %v", err)
	}
	// Add some data and vacuum again
	partition := createTestPartitionForSQL()
	partitionKey := partitionKey(partition)
	// Fix: the save error was previously ignored; check it so the second
	// vacuum actually runs against a database that contains data.
	if err := storage.SaveOffsetMapping(partitionKey, 0, 1000, 100); err != nil {
		t.Fatalf("Failed to save offset mapping: %v", err)
	}
	err = storage.Vacuum()
	if err != nil {
		t.Fatalf("Failed to vacuum database with data: %v", err)
	}
}
// TestSQLOffsetStorage_ConcurrentAccess writes disjoint offset ranges from
// multiple goroutines and verifies every mapping lands in storage. Each
// goroutine owns its own offset range, so there are no write conflicts.
func TestSQLOffsetStorage_ConcurrentAccess(t *testing.T) {
db := createTestDB(t)
storage, err := NewSQLOffsetStorage(db)
if err != nil {
t.Fatalf("Failed to create SQL storage: %v", err)
}
defer storage.Close()
partition := createTestPartitionForSQL()
partitionKey := partitionKey(partition)
// Test concurrent writes
const numGoroutines = 10
const offsetsPerGoroutine = 10
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(goroutineID int) {
defer func() { done <- true }()
for j := 0; j < offsetsPerGoroutine; j++ {
offset := int64(goroutineID*offsetsPerGoroutine + j)
// t.Errorf (unlike t.Fatalf) is safe to call from a goroutine.
err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
if err != nil {
t.Errorf("Failed to save offset mapping %d: %v", offset, err)
return
}
}
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < numGoroutines; i++ {
<-done
}
// Verify all mappings were saved
entries, err := storage.LoadOffsetMappings(partitionKey)
if err != nil {
t.Fatalf("Failed to load mappings: %v", err)
}
expectedCount := numGoroutines * offsetsPerGoroutine
if len(entries) != expectedCount {
t.Errorf("Expected %d mappings, got %d", expectedCount, len(entries))
}
}

View File

@@ -1,457 +0,0 @@
package offset
import (
"testing"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
// TestOffsetSubscriber_CreateSubscription covers the two main start modes:
// EXACT_OFFSET positions the subscription at the requested offset, while
// RESET_TO_LATEST positions it at the high water mark.
func TestOffsetSubscriber_CreateSubscription(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign some offsets first so the high water mark is 10.
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
// Test EXACT_OFFSET subscription
sub, err := subscriber.CreateSubscription("test-sub-1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
if err != nil {
t.Fatalf("Failed to create EXACT_OFFSET subscription: %v", err)
}
if sub.StartOffset != 5 {
t.Errorf("Expected start offset 5, got %d", sub.StartOffset)
}
if sub.CurrentOffset != 5 {
t.Errorf("Expected current offset 5, got %d", sub.CurrentOffset)
}
// Test RESET_TO_LATEST subscription
sub2, err := subscriber.CreateSubscription("test-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
if err != nil {
t.Fatalf("Failed to create RESET_TO_LATEST subscription: %v", err)
}
if sub2.StartOffset != 10 { // Should be at high water mark
t.Errorf("Expected start offset 10, got %d", sub2.StartOffset)
}
}
// TestOffsetSubscriber_InvalidSubscription verifies that a subscription
// cannot start beyond the high water mark or at a negative offset.
func TestOffsetSubscriber_InvalidSubscription(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	registry := NewPartitionOffsetRegistry(storage)
	subscriber := NewOffsetSubscriber(registry)
	partition := createTestPartition()
	// Five assigned offsets mean valid exact offsets are 0..4.
	registry.AssignOffsets("test-namespace", "test-topic", partition, 5)
	if _, err := subscriber.CreateSubscription("invalid-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 10); err == nil {
		t.Error("Expected error for offset beyond high water mark")
	}
	if _, err := subscriber.CreateSubscription("invalid-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, -1); err == nil {
		t.Error("Expected error for negative offset")
	}
}
// TestOffsetSubscriber_DuplicateSubscription verifies that reusing a
// subscription ID is rejected while the first subscription is still live.
func TestOffsetSubscriber_DuplicateSubscription(t *testing.T) {
	storage := NewInMemoryOffsetStorage()
	registry := NewPartitionOffsetRegistry(storage)
	subscriber := NewOffsetSubscriber(registry)
	partition := createTestPartition()
	// First registration with this ID must succeed.
	if _, err := subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0); err != nil {
		t.Fatalf("Failed to create first subscription: %v", err)
	}
	// A second registration with the identical ID must be refused.
	if _, err := subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0); err == nil {
		t.Error("Expected error for duplicate subscription ID")
	}
}
// TestOffsetSubscription_SeekToOffset verifies seeking within the valid
// offset window updates CurrentOffset, while seeks beyond the high water
// mark or to a negative offset are rejected.
func TestOffsetSubscription_SeekToOffset(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 20.
registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
// Create subscription
sub, err := subscriber.CreateSubscription("seek-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Test valid seek
err = sub.SeekToOffset(10)
if err != nil {
t.Fatalf("Failed to seek to offset 10: %v", err)
}
if sub.CurrentOffset != 10 {
t.Errorf("Expected current offset 10, got %d", sub.CurrentOffset)
}
// Test invalid seek (beyond high water mark)
err = sub.SeekToOffset(25)
if err == nil {
t.Error("Expected error for seek beyond high water mark")
}
// Test negative seek
err = sub.SeekToOffset(-1)
if err == nil {
t.Error("Expected error for negative seek offset")
}
}
// TestOffsetSubscription_AdvanceOffset verifies that AdvanceOffset moves
// the next offset forward by one and AdvanceOffsetBy moves it by a count,
// relative to whatever the subscription's starting position was.
func TestOffsetSubscription_AdvanceOffset(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Create subscription
sub, err := subscriber.CreateSubscription("advance-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Test single advance; assertions are relative to the initial position.
initialOffset := sub.GetNextOffset()
sub.AdvanceOffset()
if sub.GetNextOffset() != initialOffset+1 {
t.Errorf("Expected offset %d, got %d", initialOffset+1, sub.GetNextOffset())
}
// Test batch advance
sub.AdvanceOffsetBy(5)
if sub.GetNextOffset() != initialOffset+6 {
t.Errorf("Expected offset %d, got %d", initialOffset+6, sub.GetNextOffset())
}
}
// TestOffsetSubscription_GetLag verifies lag is computed as
// highWaterMark - currentOffset, both at subscription creation and after
// the subscription advances.
func TestOffsetSubscription_GetLag(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 15.
registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
// Create subscription at offset 5
sub, err := subscriber.CreateSubscription("lag-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Check initial lag
lag, err := sub.GetLag()
if err != nil {
t.Fatalf("Failed to get lag: %v", err)
}
expectedLag := int64(15 - 5) // hwm - current
if lag != expectedLag {
t.Errorf("Expected lag %d, got %d", expectedLag, lag)
}
// Advance and check lag again
sub.AdvanceOffsetBy(3)
lag, err = sub.GetLag()
if err != nil {
t.Fatalf("Failed to get lag after advance: %v", err)
}
expectedLag = int64(15 - 8) // hwm - current
if lag != expectedLag {
t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
}
}
// TestOffsetSubscription_IsAtEnd verifies that a RESET_TO_LATEST
// subscription starts at the end of the partition, and that seeking back
// into the middle clears the at-end condition.
func TestOffsetSubscription_IsAtEnd(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 10.
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
// Create subscription at end
sub, err := subscriber.CreateSubscription("end-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Should be at end
atEnd, err := sub.IsAtEnd()
if err != nil {
t.Fatalf("Failed to check if at end: %v", err)
}
if !atEnd {
t.Error("Expected subscription to be at end")
}
// Seek to middle and check again
sub.SeekToOffset(5)
atEnd, err = sub.IsAtEnd()
if err != nil {
t.Fatalf("Failed to check if at end after seek: %v", err)
}
if atEnd {
t.Error("Expected subscription not to be at end after seek")
}
}
// TestOffsetSubscription_GetOffsetRange verifies that a requested batch is
// returned as an inclusive [start, end] window from the current position,
// and that a batch reaching past the high water mark is truncated to hwm-1.
func TestOffsetSubscription_GetOffsetRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 20.
registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
// Create subscription
sub, err := subscriber.CreateSubscription("range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Test normal range: 10 offsets from position 5 -> [5, 14].
offsetRange, err := sub.GetOffsetRange(10)
if err != nil {
t.Fatalf("Failed to get offset range: %v", err)
}
if offsetRange.StartOffset != 5 {
t.Errorf("Expected start offset 5, got %d", offsetRange.StartOffset)
}
if offsetRange.EndOffset != 14 {
t.Errorf("Expected end offset 14, got %d", offsetRange.EndOffset)
}
if offsetRange.Count != 10 {
t.Errorf("Expected count 10, got %d", offsetRange.Count)
}
// Test range that exceeds high water mark: only 5 offsets remain.
sub.SeekToOffset(15)
offsetRange, err = sub.GetOffsetRange(10)
if err != nil {
t.Fatalf("Failed to get offset range near end: %v", err)
}
if offsetRange.StartOffset != 15 {
t.Errorf("Expected start offset 15, got %d", offsetRange.StartOffset)
}
if offsetRange.EndOffset != 19 { // Should be capped at hwm-1
t.Errorf("Expected end offset 19, got %d", offsetRange.EndOffset)
}
if offsetRange.Count != 5 {
t.Errorf("Expected count 5, got %d", offsetRange.Count)
}
}
// TestOffsetSubscription_EmptyRange verifies that requesting a batch while
// positioned at the high water mark yields an empty range, encoded as
// count 0 with end < start.
func TestOffsetSubscription_EmptyRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
subscriber := NewOffsetSubscriber(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 10.
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
// Create subscription at end
sub, err := subscriber.CreateSubscription("empty-range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
if err != nil {
t.Fatalf("Failed to create subscription: %v", err)
}
// Request range when at end
offsetRange, err := sub.GetOffsetRange(5)
if err != nil {
t.Fatalf("Failed to get offset range at end: %v", err)
}
if offsetRange.Count != 0 {
t.Errorf("Expected empty range (count 0), got count %d", offsetRange.Count)
}
if offsetRange.StartOffset != 10 {
t.Errorf("Expected start offset 10, got %d", offsetRange.StartOffset)
}
if offsetRange.EndOffset != 9 { // Empty range: end < start
t.Errorf("Expected end offset 9 (empty range), got %d", offsetRange.EndOffset)
}
}
// TestOffsetSeeker_ValidateOffsetRange table-tests range validation against
// a partition with 15 assigned offsets (valid offsets 0..14): negative
// starts, inverted ranges, and starts beyond the high water mark must fail.
func TestOffsetSeeker_ValidateOffsetRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
seeker := NewOffsetSeeker(registry)
partition := createTestPartition()
// Assign offsets so the high water mark is 15.
registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
// Test valid range
err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10)
if err != nil {
t.Errorf("Valid range should not return error: %v", err)
}
// Test invalid ranges
testCases := []struct {
name string
startOffset int64
endOffset int64
expectError bool
}{
{"negative start", -1, 5, true},
{"end before start", 10, 5, true},
{"start beyond hwm", 20, 25, true},
{"valid range", 0, 14, false},
{"single offset", 5, 5, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, tc.startOffset, tc.endOffset)
if tc.expectError && err == nil {
t.Error("Expected error but got none")
}
if !tc.expectError && err != nil {
t.Errorf("Expected no error but got: %v", err)
}
})
}
}
// TestOffsetSeeker_GetAvailableOffsetRange verifies the available range is
// empty for a fresh partition and spans [0, hwm-1] once offsets are
// assigned.
func TestOffsetSeeker_GetAvailableOffsetRange(t *testing.T) {
storage := NewInMemoryOffsetStorage()
registry := NewPartitionOffsetRegistry(storage)
seeker := NewOffsetSeeker(registry)
partition := createTestPartition()
// Test empty partition
offsetRange, err := seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get available range for empty partition: %v", err)
}
if offsetRange.Count != 0 {
t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
}
// Assign offsets and test again: 25 offsets -> available range [0, 24].
registry.AssignOffsets("test-namespace", "test-topic", partition, 25)
offsetRange, err = seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
if err != nil {
t.Fatalf("Failed to get available range: %v", err)
}
if offsetRange.StartOffset != 0 {
t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
}
if offsetRange.EndOffset != 24 {
t.Errorf("Expected end offset 24, got %d", offsetRange.EndOffset)
}
if offsetRange.Count != 25 {
t.Errorf("Expected count 25, got %d", offsetRange.Count)
}
}
// TestOffsetSubscriber_CloseSubscription verifies that closing a subscription
// removes it from the subscriber and marks the returned handle inactive.
func TestOffsetSubscriber_CloseSubscription(t *testing.T) {
	store := NewInMemoryOffsetStorage()
	reg := NewPartitionOffsetRegistry(store)
	subs := NewOffsetSubscriber(reg)
	part := createTestPartition()

	// Create a subscription and confirm it is retrievable.
	handle, err := subs.CreateSubscription("close-test", "test-namespace", "test-topic", part, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
	if err != nil {
		t.Fatalf("Failed to create subscription: %v", err)
	}
	if _, err = subs.GetSubscription("close-test"); err != nil {
		t.Fatalf("Subscription should exist: %v", err)
	}

	// Closing must succeed...
	if err = subs.CloseSubscription("close-test"); err != nil {
		t.Fatalf("Failed to close subscription: %v", err)
	}
	// ...and afterwards the subscription must be gone and the handle inactive.
	if _, err = subs.GetSubscription("close-test"); err == nil {
		t.Error("Subscription should not exist after close")
	}
	if handle.IsActive {
		t.Error("Subscription should be marked inactive after close")
	}
}
// TestOffsetSubscription_InactiveOperations verifies that every operation on
// a closed (inactive) subscription handle returns an error.
func TestOffsetSubscription_InactiveOperations(t *testing.T) {
	store := NewInMemoryOffsetStorage()
	reg := NewPartitionOffsetRegistry(store)
	subs := NewOffsetSubscriber(reg)
	part := createTestPartition()

	// Create a subscription and immediately close it.
	handle, err := subs.CreateSubscription("inactive-test", "test-namespace", "test-topic", part, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
	if err != nil {
		t.Fatalf("Failed to create subscription: %v", err)
	}
	subs.CloseSubscription("inactive-test")

	// Every operation on the inactive handle must fail.
	if err = handle.SeekToOffset(5); err == nil {
		t.Error("Expected error for seek on inactive subscription")
	}
	if _, err = handle.GetLag(); err == nil {
		t.Error("Expected error for GetLag on inactive subscription")
	}
	if _, err = handle.IsAtEnd(); err == nil {
		t.Error("Expected error for IsAtEnd on inactive subscription")
	}
	if _, err = handle.GetOffsetRange(10); err == nil {
		t.Error("Expected error for GetOffsetRange on inactive subscription")
	}
}

View File

@@ -1,13 +1,6 @@
package pub_balancer package pub_balancer
import ( import ()
"math/rand/v2"
"sort"
cmap "github.com/orcaman/concurrent-map/v2"
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
"modernc.org/mathutil"
)
func (balancer *PubBalancer) RepairTopics() []BalanceAction { func (balancer *PubBalancer) RepairTopics() []BalanceAction {
action := BalanceTopicPartitionOnBrokers(balancer.Brokers) action := BalanceTopicPartitionOnBrokers(balancer.Brokers)
@@ -17,107 +10,3 @@ func (balancer *PubBalancer) RepairTopics() []BalanceAction {
type TopicPartitionInfo struct { type TopicPartitionInfo struct {
Broker string Broker string
} }
// RepairMissingTopicPartitions checks the stats of all brokers and returns
// BalanceActions that re-create any topic partitions missing from the ring,
// placing each re-created partition on a randomly chosen broker.
func RepairMissingTopicPartitions(brokers cmap.ConcurrentMap[string, *BrokerStats]) (actions []BalanceAction) {
	// find all topic partitions reported by any broker
	topicToTopicPartitions := make(map[topic.Topic]map[topic.Partition]*TopicPartitionInfo)
	for brokerStatsItem := range brokers.IterBuffered() {
		broker, brokerStats := brokerStatsItem.Key, brokerStatsItem.Val
		for topicPartitionStatsItem := range brokerStats.TopicPartitionStats.IterBuffered() {
			topicPartitionStat := topicPartitionStatsItem.Val
			topicPartitionToInfo, found := topicToTopicPartitions[topicPartitionStat.Topic]
			if !found {
				topicPartitionToInfo = make(map[topic.Partition]*TopicPartitionInfo)
				topicToTopicPartitions[topicPartitionStat.Topic] = topicPartitionToInfo
			}
			tpi, found := topicPartitionToInfo[topicPartitionStat.Partition]
			if !found {
				tpi = &TopicPartitionInfo{}
				topicPartitionToInfo[topicPartitionStat.Partition] = tpi
			}
			// NOTE(review): if several brokers report the same partition, the
			// last broker iterated wins — presumably each partition has a
			// single owner; confirm against BrokerStats semantics.
			tpi.Broker = broker
		}
	}
	// collect all brokers as placement candidates for re-created partitions
	candidates := make([]string, 0, brokers.Count())
	for brokerStatsItem := range brokers.IterBuffered() {
		candidates = append(candidates, brokerStatsItem.Key)
	}
	// find the missing topic partitions and schedule a create action for each,
	// targeting a random candidate broker
	for t, topicPartitionToInfo := range topicToTopicPartitions {
		missingPartitions := EachTopicRepairMissingTopicPartitions(t, topicPartitionToInfo)
		for _, partition := range missingPartitions {
			actions = append(actions, BalanceActionCreate{
				TopicPartition: topic.TopicPartition{
					Topic:     t,
					Partition: partition,
				},
				TargetBroker: candidates[rand.IntN(len(candidates))],
			})
		}
	}
	return actions
}
// EachTopicRepairMissingTopicPartitions computes the partitions missing for a
// single topic. The known partitions are taken from the keys of info and
// checked against the full ring of MaxPartitionCount slots. The topic
// argument t is currently unused by the implementation.
func EachTopicRepairMissingTopicPartitions(t topic.Topic, info map[topic.Partition]*TopicPartitionInfo) (missingPartitions []topic.Partition) {
	// collect the known partitions (map keys) for this topic
	var partitions []topic.Partition
	for partition := range info {
		partitions = append(partitions, partition)
	}
	return findMissingPartitions(partitions, MaxPartitionCount)
}
// findMissingPartitions returns the partitions needed to cover the gaps in
// the ring [0, ringSize) that are not covered by the given partitions. Gaps
// are filled with slices no larger than the average size of the known
// partitions, so the repaired layout resembles the existing one.
// Note: the input slice is sorted in place.
func findMissingPartitions(partitions []topic.Partition, ringSize int32) (missingPartitions []topic.Partition) {
	// With no known partitions the whole ring is missing and there is no
	// average size to infer (the division below would panic), so cover the
	// ring with a single partition.
	if len(partitions) == 0 {
		return []topic.Partition{{RangeStart: 0, RangeStop: ringSize, RingSize: ringSize}}
	}
	// sort the partitions by range start
	sort.Slice(partitions, func(i, j int) bool {
		return partitions[i].RangeStart < partitions[j].RangeStart
	})
	// calculate the average partition size
	var covered int32
	for _, partition := range partitions {
		covered += partition.RangeStop - partition.RangeStart
	}
	averagePartitionSize := covered / int32(len(partitions))
	if averagePartitionSize <= 0 {
		// All known partitions are zero-width. Without this guard the fill
		// loops below would never advance the watermark and spin forever.
		averagePartitionSize = ringSize
	}
	// walk the ring, filling any gap before each known partition
	var coveredWatermark int32
	i := 0
	for i < len(partitions) {
		partition := partitions[i]
		if partition.RangeStart > coveredWatermark {
			// fill at most averagePartitionSize of the gap per iteration
			upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, partition.RangeStart)
			missingPartitions = append(missingPartitions, topic.Partition{
				RangeStart: coveredWatermark,
				RangeStop:  upperBound,
				RingSize:   ringSize,
			})
			coveredWatermark = upperBound
			if coveredWatermark == partition.RangeStop {
				i++
			}
		} else {
			coveredWatermark = partition.RangeStop
			i++
		}
	}
	// fill the tail gap after the last known partition, if any
	for coveredWatermark < ringSize {
		upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, ringSize)
		missingPartitions = append(missingPartitions, topic.Partition{
			RangeStart: coveredWatermark,
			RangeStop:  upperBound,
			RingSize:   ringSize,
		})
		coveredWatermark = upperBound
	}
	return missingPartitions
}

View File

@@ -1,98 +0,0 @@
package pub_balancer
import (
"reflect"
"testing"
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
)
// Test_findMissingPartitions is a table test covering gap-filling over a
// 1024-slot ring: fully covered rings yield nil, and gaps at the tail, head,
// and middle are filled with 256-wide slices (the average size of the known
// partitions in each case).
func Test_findMissingPartitions(t *testing.T) {
	type args struct {
		partitions []topic.Partition
	}
	tests := []struct {
		name                  string
		args                  args
		wantMissingPartitions []topic.Partition
	}{
		{
			name: "one partition",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 0, RangeStop: 1024},
				},
			},
			wantMissingPartitions: nil,
		},
		{
			name: "two partitions",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 0, RangeStop: 512},
					{RingSize: 1024, RangeStart: 512, RangeStop: 1024},
				},
			},
			wantMissingPartitions: nil,
		},
		{
			name: "four partitions, missing last two",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 0, RangeStop: 256},
					{RingSize: 1024, RangeStart: 256, RangeStop: 512},
				},
			},
			wantMissingPartitions: []topic.Partition{
				{RingSize: 1024, RangeStart: 512, RangeStop: 768},
				{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
			},
		},
		{
			name: "four partitions, missing first two",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 512, RangeStop: 768},
					{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
				},
			},
			wantMissingPartitions: []topic.Partition{
				{RingSize: 1024, RangeStart: 0, RangeStop: 256},
				{RingSize: 1024, RangeStart: 256, RangeStop: 512},
			},
		},
		{
			name: "four partitions, missing middle two",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 0, RangeStop: 256},
					{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
				},
			},
			wantMissingPartitions: []topic.Partition{
				{RingSize: 1024, RangeStart: 256, RangeStop: 512},
				{RingSize: 1024, RangeStart: 512, RangeStop: 768},
			},
		},
		{
			name: "four partitions, missing three",
			args: args{
				partitions: []topic.Partition{
					{RingSize: 1024, RangeStart: 512, RangeStop: 768},
				},
			},
			wantMissingPartitions: []topic.Partition{
				{RingSize: 1024, RangeStart: 0, RangeStop: 256},
				{RingSize: 1024, RangeStart: 256, RangeStop: 512},
				{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if gotMissingPartitions := findMissingPartitions(tt.args.partitions, 1024); !reflect.DeepEqual(gotMissingPartitions, tt.wantMissingPartitions) {
				t.Errorf("findMissingPartitions() = %v, want %v", gotMissingPartitions, tt.wantMissingPartitions)
			}
		})
	}
}

View File

@@ -1,109 +0,0 @@
package segment
import (
flatbuffers "github.com/google/flatbuffers/go"
"github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
)
// MessageBatchBuilder accumulates messages and serializes them into a single
// flatbuffers MessageBatch. Sequence numbers and timestamps are stored as
// deltas relative to the first message added.
type MessageBatchBuilder struct {
	b              *flatbuffers.Builder
	producerId     int32
	producerEpoch  int32
	segmentId      int32
	flags          int32
	messageOffsets []flatbuffers.UOffsetT // offsets of Message tables already written into b
	segmentSeqBase int64                  // sequence number of the first message; deltas are relative to this
	segmentSeqLast int64                  // sequence number of the most recently added message
	tsMsBase       int64                  // timestamp (ms) of the first message; deltas are relative to this
	tsMsLast       int64                  // timestamp (ms) of the most recently added message
}
// NewMessageBatchBuilder resets the given flatbuffers builder and wraps it in
// a MessageBatchBuilder carrying the batch-level identifiers.
func NewMessageBatchBuilder(b *flatbuffers.Builder,
	producerId int32,
	producerEpoch int32,
	segmentId int32,
	flags int32) *MessageBatchBuilder {
	b.Reset()
	builder := &MessageBatchBuilder{}
	builder.b = b
	builder.producerId = producerId
	builder.producerEpoch = producerEpoch
	builder.segmentId = segmentId
	builder.flags = flags
	return builder
}
// AddMessage appends one message to the batch being built. The first message
// establishes the sequence and timestamp bases; every message stores its
// sequence and timestamp as deltas against those bases.
// NOTE(review): the bases are seeded only while the stored base is still 0,
// so a legitimate first segmentSeq or tsMs of 0 followed by nonzero values
// would re-seed the base — confirm callers never pass 0 for the first message.
func (builder *MessageBatchBuilder) AddMessage(segmentSeq int64, tsMs int64, properties map[string][]byte, key []byte, value []byte) {
	if builder.segmentSeqBase == 0 {
		builder.segmentSeqBase = segmentSeq
	}
	builder.segmentSeqLast = segmentSeq
	if builder.tsMsBase == 0 {
		builder.tsMsBase = tsMs
	}
	builder.tsMsLast = tsMs
	// Serialize the property name/value pairs. Strings and vectors must be
	// created before their enclosing table is started, hence the two passes.
	// Pair order follows Go map iteration and is therefore nondeterministic.
	var names, values, pairs []flatbuffers.UOffsetT
	for k, v := range properties {
		names = append(names, builder.b.CreateString(k))
		values = append(values, builder.b.CreateByteVector(v))
	}
	for i := range names { // idiomatic form of `for i, _ := range`
		message_fbs.NameValueStart(builder.b)
		message_fbs.NameValueAddName(builder.b, names[i])
		message_fbs.NameValueAddValue(builder.b, values[i])
		pair := message_fbs.NameValueEnd(builder.b)
		pairs = append(pairs, pair)
	}
	// flatbuffers vectors are built back-to-front
	message_fbs.MessageStartPropertiesVector(builder.b, len(properties))
	for i := len(pairs) - 1; i >= 0; i-- {
		builder.b.PrependUOffsetT(pairs[i])
	}
	propOffset := builder.b.EndVector(len(properties))
	keyOffset := builder.b.CreateByteVector(key)
	valueOffset := builder.b.CreateByteVector(value)
	// Assemble the Message table from the pre-built offsets and remember its
	// offset for BuildMessageBatch.
	message_fbs.MessageStart(builder.b)
	message_fbs.MessageAddSeqDelta(builder.b, int32(segmentSeq-builder.segmentSeqBase))
	message_fbs.MessageAddTsMsDelta(builder.b, int32(tsMs-builder.tsMsBase))
	message_fbs.MessageAddProperties(builder.b, propOffset)
	message_fbs.MessageAddKey(builder.b, keyOffset)
	message_fbs.MessageAddData(builder.b, valueOffset)
	messageOffset := message_fbs.MessageEnd(builder.b)
	builder.messageOffsets = append(builder.messageOffsets, messageOffset)
}
// BuildMessageBatch serializes the accumulated messages into the final
// MessageBatch table and finishes the buffer. Call GetBytes afterwards to
// retrieve the serialized bytes.
func (builder *MessageBatchBuilder) BuildMessageBatch() {
	// flatbuffers vectors are written back-to-front
	message_fbs.MessageBatchStartMessagesVector(builder.b, len(builder.messageOffsets))
	for i := len(builder.messageOffsets) - 1; i >= 0; i-- {
		builder.b.PrependUOffsetT(builder.messageOffsets[i])
	}
	messagesOffset := builder.b.EndVector(len(builder.messageOffsets))
	// batch header fields plus the seq/ts bases and their max deltas
	message_fbs.MessageBatchStart(builder.b)
	message_fbs.MessageBatchAddProducerId(builder.b, builder.producerId)
	message_fbs.MessageBatchAddProducerEpoch(builder.b, builder.producerEpoch)
	message_fbs.MessageBatchAddSegmentId(builder.b, builder.segmentId)
	message_fbs.MessageBatchAddFlags(builder.b, builder.flags)
	message_fbs.MessageBatchAddSegmentSeqBase(builder.b, builder.segmentSeqBase)
	message_fbs.MessageBatchAddSegmentSeqMaxDelta(builder.b, int32(builder.segmentSeqLast-builder.segmentSeqBase))
	message_fbs.MessageBatchAddTsMsBase(builder.b, builder.tsMsBase)
	message_fbs.MessageBatchAddTsMsMaxDelta(builder.b, int32(builder.tsMsLast-builder.tsMsBase))
	message_fbs.MessageBatchAddMessages(builder.b, messagesOffset)
	messageBatch := message_fbs.MessageBatchEnd(builder.b)
	builder.b.Finish(messageBatch)
}
// GetBytes returns the serialized batch; only valid after BuildMessageBatch
// has been called (flatbuffers FinishedBytes requires a finished buffer).
func (builder *MessageBatchBuilder) GetBytes() []byte {
	return builder.b.FinishedBytes()
}

View File

@@ -1,61 +0,0 @@
package segment
import (
"testing"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
"github.com/stretchr/testify/assert"
)
// TestMessageSerde round-trips a two-message batch through the flatbuffers
// builder and checks the batch header fields, the seq/ts bases and max
// deltas, and the key/value/deltas of the first message.
func TestMessageSerde(t *testing.T) {
	b := flatbuffers.NewBuilder(1024)
	prop := make(map[string][]byte)
	prop["n1"] = []byte("v1")
	prop["n2"] = []byte("v2")
	bb := NewMessageBatchBuilder(b, 1, 2, 3, 4)
	bb.AddMessage(5, 6, prop, []byte("the primary key"), []byte("body is here"))
	bb.AddMessage(5, 7, prop, []byte("the primary 2"), []byte("body is 2"))
	bb.BuildMessageBatch()
	buf := bb.GetBytes()
	println("serialized size", len(buf))
	mb := message_fbs.GetRootAsMessageBatch(buf, 0)
	// batch-level header fields
	assert.Equal(t, int32(1), mb.ProducerId())
	assert.Equal(t, int32(2), mb.ProducerEpoch())
	assert.Equal(t, int32(3), mb.SegmentId())
	assert.Equal(t, int32(4), mb.Flags())
	// both messages share seq 5, so the max seq delta is 0; ts runs 6..7
	assert.Equal(t, int64(5), mb.SegmentSeqBase())
	assert.Equal(t, int32(0), mb.SegmentSeqMaxDelta())
	assert.Equal(t, int64(6), mb.TsMsBase())
	assert.Equal(t, int32(1), mb.TsMsMaxDelta())
	assert.Equal(t, 2, mb.MessagesLength())
	m := &message_fbs.Message{}
	mb.Messages(m, 0)
	/*
		// the vector seems not consistent
		nv := &message_fbs.NameValue{}
		m.Properties(nv, 0)
		assert.Equal(t, "n1", string(nv.Name()))
		assert.Equal(t, "v1", string(nv.Value()))
		m.Properties(nv, 1)
		assert.Equal(t, "n2", string(nv.Name()))
		assert.Equal(t, "v2", string(nv.Value()))
	*/
	assert.Equal(t, []byte("the primary key"), m.Key())
	assert.Equal(t, []byte("body is here"), m.Data())
	assert.Equal(t, int32(0), m.SeqDelta())
	assert.Equal(t, int32(0), m.TsMsDelta())
}

View File

@@ -28,28 +28,6 @@ func (imt *InflightMessageTracker) EnflightMessage(key []byte, tsNs int64) {
imt.timestamps.EnflightTimestamp(tsNs) imt.timestamps.EnflightTimestamp(tsNs)
} }
// IsMessageAcknowledged reports whether the message with the given key and
// timestamp no longer needs delivery:
//   - true if tsNs is at or before the oldest fully-acknowledged timestamp
//     (the message is old and already accounted for);
//   - false if tsNs is newer than the latest timestamp this tracker has seen
//     (the message has not been enflighted yet);
//   - false if the key is currently inflight;
//   - true otherwise (inside the tracked window and not inflight).
func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) bool {
	imt.mu.Lock()
	defer imt.mu.Unlock()
	if tsNs <= imt.timestamps.OldestAckedTimestamp() {
		return true
	}
	if tsNs > imt.timestamps.Latest() {
		return false
	}
	if _, found := imt.messages[string(key)]; found {
		return false
	}
	return true
}
// AcknowledgeMessage acknowledges the message with the key and timestamp. // AcknowledgeMessage acknowledges the message with the key and timestamp.
func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool { func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool {
// fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs) // fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs)
@@ -164,8 +142,3 @@ func (rb *RingBuffer) AckTimestamp(timestamp int64) {
func (rb *RingBuffer) OldestAckedTimestamp() int64 { func (rb *RingBuffer) OldestAckedTimestamp() int64 {
return rb.maxAllAckedTs return rb.maxAllAckedTs
} }
// Latest returns the most recently known timestamp in the ring buffer
// (rb.maxTimestamp), regardless of acknowledgement state.
func (rb *RingBuffer) Latest() int64 {
	return rb.maxTimestamp
}

View File

@@ -1,134 +0,0 @@
package sub_coordinator
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestRingBuffer walks a RingBuffer through enflight/ack cycles and checks
// that OldestAckedTimestamp advances only once a contiguous prefix of the
// enflighted timestamps has been acknowledged, while Latest tracks the
// newest enflighted timestamp.
func TestRingBuffer(t *testing.T) {
	// Initialize a RingBuffer with capacity 5
	rb := NewRingBuffer(5)
	// Add timestamps to the buffer
	timestamps := []int64{100, 200, 300, 400, 500}
	for _, ts := range timestamps {
		rb.EnflightTimestamp(ts)
	}
	// Test Add method and buffer size
	expectedSize := 5
	if rb.size != expectedSize {
		t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size)
	}
	assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
	assert.Equal(t, int64(500), rb.Latest())
	// out-of-order ack: 200 alone does not advance the oldest-acked mark
	rb.AckTimestamp(200)
	assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
	rb.AckTimestamp(100)
	assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
	rb.EnflightTimestamp(int64(600))
	rb.EnflightTimestamp(int64(700))
	// 300 and 400 remain unacked, so the mark stays at 200 until 300 lands
	rb.AckTimestamp(500)
	assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
	rb.AckTimestamp(400)
	assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
	rb.AckTimestamp(300)
	assert.Equal(t, int64(500), rb.OldestAckedTimestamp())
	assert.Equal(t, int64(700), rb.Latest())
}
// TestInflightMessageTracker covers the basic enflight/ack lifecycle of a
// single message: timestamps older than the tracked window count as
// acknowledged, and acking removes the message and drains the buffer.
func TestInflightMessageTracker(t *testing.T) {
	// Initialize an InflightMessageTracker with capacity 5
	tracker := NewInflightMessageTracker(5)
	// Add inflight messages
	key := []byte("1")
	timestamp := int64(1)
	tracker.EnflightMessage(key, timestamp)
	// Test IsMessageAcknowledged method: a timestamp before the tracked
	// window is treated as already acknowledged ("old")
	isOld := tracker.IsMessageAcknowledged(key, timestamp-10)
	if !isOld {
		t.Error("Expected message to be old")
	}
	// Test AcknowledgeMessage method
	acked := tracker.AcknowledgeMessage(key, timestamp)
	if !acked {
		t.Error("Expected message to be acked")
	}
	if _, exists := tracker.messages[string(key)]; exists {
		t.Error("Expected message to be deleted after ack")
	}
	if tracker.timestamps.size != 0 {
		t.Error("Expected buffer size to be 0 after ack")
	}
	assert.Equal(t, timestamp, tracker.GetOldestAckedTimestamp())
}
// TestInflightMessageTracker2 checks in-order acknowledgement across more
// messages than the initial capacity, and that an unacked message is not
// reported as acknowledged.
func TestInflightMessageTracker2(t *testing.T) {
	// Initialize an InflightMessageTracker with initial capacity 1
	tracker := NewInflightMessageTracker(1)
	tracker.EnflightMessage([]byte("1"), int64(1))
	tracker.EnflightMessage([]byte("2"), int64(2))
	tracker.EnflightMessage([]byte("3"), int64(3))
	tracker.EnflightMessage([]byte("4"), int64(4))
	tracker.EnflightMessage([]byte("5"), int64(5))
	assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
	assert.Equal(t, int64(1), tracker.GetOldestAckedTimestamp())
	// Test IsMessageAcknowledged method: message "2" is still inflight
	isAcked := tracker.IsMessageAcknowledged([]byte("2"), int64(2))
	if isAcked {
		t.Error("Expected message to be not acked")
	}
	// Test AcknowledgeMessage method
	assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
	assert.Equal(t, int64(2), tracker.GetOldestAckedTimestamp())
}
// TestInflightMessageTracker3 interleaves enflighting and acknowledging so
// the buffer must grow and recycle while the oldest-acked watermark advances
// only over contiguous acknowledged prefixes.
func TestInflightMessageTracker3(t *testing.T) {
	// Initialize an InflightMessageTracker with initial capacity 1
	tracker := NewInflightMessageTracker(1)
	tracker.EnflightMessage([]byte("1"), int64(1))
	tracker.EnflightMessage([]byte("2"), int64(2))
	tracker.EnflightMessage([]byte("3"), int64(3))
	assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
	tracker.EnflightMessage([]byte("4"), int64(4))
	tracker.EnflightMessage([]byte("5"), int64(5))
	assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
	assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
	tracker.EnflightMessage([]byte("6"), int64(6))
	tracker.EnflightMessage([]byte("7"), int64(7))
	assert.True(t, tracker.AcknowledgeMessage([]byte("4"), int64(4)))
	assert.True(t, tracker.AcknowledgeMessage([]byte("5"), int64(5)))
	assert.True(t, tracker.AcknowledgeMessage([]byte("6"), int64(6)))
	assert.Equal(t, int64(6), tracker.GetOldestAckedTimestamp())
	assert.True(t, tracker.AcknowledgeMessage([]byte("7"), int64(7)))
	assert.Equal(t, int64(7), tracker.GetOldestAckedTimestamp())
}
// TestInflightMessageTracker4 verifies the watermark keeps advancing after
// the tracker has been fully drained and a new message is enflighted.
func TestInflightMessageTracker4(t *testing.T) {
	// Initialize an InflightMessageTracker with initial capacity 1
	tracker := NewInflightMessageTracker(1)
	tracker.EnflightMessage([]byte("1"), int64(1))
	tracker.EnflightMessage([]byte("2"), int64(2))
	assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
	assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
	tracker.EnflightMessage([]byte("3"), int64(3))
	assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
	assert.Equal(t, int64(3), tracker.GetOldestAckedTimestamp())
}

View File

@@ -1,130 +1,6 @@
package sub_coordinator package sub_coordinator
import (
"fmt"
"time"
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
)
type PartitionConsumerMapping struct { type PartitionConsumerMapping struct {
currentMapping *PartitionSlotToConsumerInstanceList currentMapping *PartitionSlotToConsumerInstanceList
prevMappings []*PartitionSlotToConsumerInstanceList prevMappings []*PartitionSlotToConsumerInstanceList
} }
// Balance goal:
//  1. max processing power utilization
//  2. allow one consumer instance to be down unexpectedly
//     without affecting the processing power utilization
//
// BalanceToConsumerInstances computes a new partition-to-consumer assignment
// via doBalanceSticky (keeping assignments stable where possible) and
// archives the outgoing mapping, retaining at most the 10 most recent.
func (pcm *PartitionConsumerMapping) BalanceToConsumerInstances(partitionSlotToBrokerList *pub_balancer.PartitionSlotToBrokerList, consumerInstances []*ConsumerGroupInstance) {
	// nothing to balance without both partitions and consumers
	if len(partitionSlotToBrokerList.PartitionSlots) == 0 || len(consumerInstances) == 0 {
		return
	}
	newMapping := NewPartitionSlotToConsumerInstanceList(partitionSlotToBrokerList.RingSize, time.Now())
	// stick to the most recent previous mapping, if any
	var prevMapping *PartitionSlotToConsumerInstanceList
	if len(pcm.prevMappings) > 0 {
		prevMapping = pcm.prevMappings[len(pcm.prevMappings)-1]
	} else {
		prevMapping = nil
	}
	newMapping.PartitionSlots = doBalanceSticky(partitionSlotToBrokerList.PartitionSlots, consumerInstances, prevMapping)
	if pcm.currentMapping != nil {
		// archive the outgoing mapping, capping history at 10 entries
		pcm.prevMappings = append(pcm.prevMappings, pcm.currentMapping)
		if len(pcm.prevMappings) > 10 {
			pcm.prevMappings = pcm.prevMappings[1:]
		}
	}
	pcm.currentMapping = newMapping
}
// doBalanceSticky assigns partitions to consumer instances, preferring to
// keep each partition on the instance that held it in prevMapping ("sticky")
// as long as that instance still exists. Remaining unassigned partitions are
// then handed out round-robin to instances whose current load is below the
// average partitions-per-instance.
func doBalanceSticky(partitions []*pub_balancer.PartitionSlotToBroker, consumerInstances []*ConsumerGroupInstance, prevMapping *PartitionSlotToConsumerInstanceList) (partitionSlots []*PartitionSlotToConsumerInstance) {
	// collect previous consumer instance ids
	prevConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
	if prevMapping != nil {
		for _, prevPartitionSlot := range prevMapping.PartitionSlots {
			if prevPartitionSlot.AssignedInstanceId != "" {
				prevConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId] = struct{}{}
			}
		}
	}
	// collect current consumer instance ids
	currConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
	for _, consumerInstance := range consumerInstances {
		currConsumerInstanceIds[consumerInstance.InstanceId] = struct{}{}
	}
	// check deleted consumer instances (present before, gone now)
	deletedConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
	for consumerInstanceId := range prevConsumerInstanceIds {
		if _, ok := currConsumerInstanceIds[consumerInstanceId]; !ok {
			deletedConsumerInstanceIds[consumerInstanceId] = struct{}{}
		}
	}
	// convert previous partition slots from list to a map keyed by "start-stop"
	prevPartitionSlotMap := make(map[string]*PartitionSlotToConsumerInstance)
	if prevMapping != nil {
		for _, partitionSlot := range prevMapping.PartitionSlots {
			key := fmt.Sprintf("%d-%d", partitionSlot.RangeStart, partitionSlot.RangeStop)
			prevPartitionSlotMap[key] = partitionSlot
		}
	}
	// build fresh slots for the current partitions (unassigned for now)
	newPartitionSlots := make([]*PartitionSlotToConsumerInstance, 0, len(partitions))
	for _, partition := range partitions {
		newPartitionSlots = append(newPartitionSlots, &PartitionSlotToConsumerInstance{
			RangeStart:     partition.RangeStart,
			RangeStop:      partition.RangeStop,
			UnixTimeNs:     partition.UnixTimeNs,
			Broker:         partition.AssignedBroker,
			FollowerBroker: partition.FollowerBroker,
		})
	}
	// carry over sticky assignments for partitions whose previous owner survives
	for _, newPartitionSlot := range newPartitionSlots {
		key := fmt.Sprintf("%d-%d", newPartitionSlot.RangeStart, newPartitionSlot.RangeStop)
		if prevPartitionSlot, ok := prevPartitionSlotMap[key]; ok {
			if _, ok := deletedConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId]; !ok {
				newPartitionSlot.AssignedInstanceId = prevPartitionSlot.AssignedInstanceId
			}
		}
	}
	// for all consumer instances, count the number of partitions
	// already assigned to them
	consumerInstancePartitionCount := make(map[ConsumerGroupInstanceId]int)
	for _, newPartitionSlot := range newPartitionSlots {
		if newPartitionSlot.AssignedInstanceId != "" {
			consumerInstancePartitionCount[newPartitionSlot.AssignedInstanceId]++
		}
	}
	// average number of partitions that are assigned to each consumer instance
	averageConsumerInstanceLoad := float32(len(partitions)) / float32(len(consumerInstances))
	// assign unassigned partition slots to underloaded consumer instances,
	// cycling through the instance list; avoidDeadLoop bounds each scan so a
	// ring with no underloaded instance cannot spin forever
	consumerInstanceIdsIndex := 0
	for _, newPartitionSlot := range newPartitionSlots {
		if newPartitionSlot.AssignedInstanceId == "" {
			for avoidDeadLoop := len(consumerInstances); avoidDeadLoop > 0; avoidDeadLoop-- {
				consumerInstance := consumerInstances[consumerInstanceIdsIndex]
				if float32(consumerInstancePartitionCount[consumerInstance.InstanceId]) < averageConsumerInstanceLoad {
					newPartitionSlot.AssignedInstanceId = consumerInstance.InstanceId
					consumerInstancePartitionCount[consumerInstance.InstanceId]++
					consumerInstanceIdsIndex++
					if consumerInstanceIdsIndex >= len(consumerInstances) {
						consumerInstanceIdsIndex = 0
					}
					break
				} else {
					consumerInstanceIdsIndex++
					if consumerInstanceIdsIndex >= len(consumerInstances) {
						consumerInstanceIdsIndex = 0
					}
				}
			}
		}
	}
	return newPartitionSlots
}

View File

@@ -1,385 +0,0 @@
package sub_coordinator
import (
"reflect"
"testing"
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
)
// Test_doBalanceSticky is a table test for doBalanceSticky covering: trivial
// one-to-one assignment, more consumers than partitions, more partitions than
// consumers, sticky reassignment when a previous owner disappears, stickiness
// when a new consumer joins, and distribution of newly added partitions.
func Test_doBalanceSticky(t *testing.T) {
	type args struct {
		partitions          []*pub_balancer.PartitionSlotToBroker
		consumerInstanceIds []*ConsumerGroupInstance
		prevMapping         *PartitionSlotToConsumerInstanceList
	}
	tests := []struct {
		name               string
		args               args
		wantPartitionSlots []*PartitionSlotToConsumerInstance
	}{
		{
			name: "1 consumer instance, 1 partition",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: nil,
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-1",
				},
			},
		},
		{
			name: "2 consumer instances, 1 partition",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: nil,
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-1",
				},
			},
		},
		{
			name: "1 consumer instance, 2 partitions",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: nil,
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-1",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-1",
				},
			},
		},
		{
			name: "2 consumer instances, 2 partitions",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: nil,
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-1",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-2",
				},
			},
		},
		{
			// previous owner of the first partition is gone; it must be
			// reassigned while the second partition stays sticky
			name: "2 consumer instances, 2 partitions, 1 deleted consumer instance",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: &PartitionSlotToConsumerInstanceList{
					PartitionSlots: []*PartitionSlotToConsumerInstance{
						{
							RangeStart:         0,
							RangeStop:          50,
							AssignedInstanceId: "consumer-instance-3",
						},
						{
							RangeStart:         50,
							RangeStop:          100,
							AssignedInstanceId: "consumer-instance-2",
						},
					},
				},
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-1",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-2",
				},
			},
		},
		{
			// a new consumer joins but both partitions keep their owners
			name: "2 consumer instances, 2 partitions, 1 new consumer instance",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-3",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: &PartitionSlotToConsumerInstanceList{
					PartitionSlots: []*PartitionSlotToConsumerInstance{
						{
							RangeStart:         0,
							RangeStop:          50,
							AssignedInstanceId: "consumer-instance-3",
						},
						{
							RangeStart:         50,
							RangeStop:          100,
							AssignedInstanceId: "consumer-instance-2",
						},
					},
				},
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-3",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-2",
				},
			},
		},
		{
			// the extra partition goes to the first underloaded instance
			name: "2 consumer instances, 2 partitions, 1 new partition",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
					{
						RangeStart: 100,
						RangeStop:  150,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: &PartitionSlotToConsumerInstanceList{
					PartitionSlots: []*PartitionSlotToConsumerInstance{
						{
							RangeStart:         0,
							RangeStop:          50,
							AssignedInstanceId: "consumer-instance-1",
						},
						{
							RangeStart:         50,
							RangeStop:          100,
							AssignedInstanceId: "consumer-instance-2",
						},
					},
				},
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-1",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-2",
				},
				{
					RangeStart:         100,
					RangeStop:          150,
					AssignedInstanceId: "consumer-instance-1",
				},
			},
		},
		{
			// the new partition lands on the newly joined (idle) instance
			name: "2 consumer instances, 2 partitions, 1 new partition, 1 new consumer instance",
			args: args{
				partitions: []*pub_balancer.PartitionSlotToBroker{
					{
						RangeStart: 0,
						RangeStop:  50,
					},
					{
						RangeStart: 50,
						RangeStop:  100,
					},
					{
						RangeStart: 100,
						RangeStop:  150,
					},
				},
				consumerInstanceIds: []*ConsumerGroupInstance{
					{
						InstanceId:        "consumer-instance-1",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-2",
						MaxPartitionCount: 1,
					},
					{
						InstanceId:        "consumer-instance-3",
						MaxPartitionCount: 1,
					},
				},
				prevMapping: &PartitionSlotToConsumerInstanceList{
					PartitionSlots: []*PartitionSlotToConsumerInstance{
						{
							RangeStart:         0,
							RangeStop:          50,
							AssignedInstanceId: "consumer-instance-1",
						},
						{
							RangeStart:         50,
							RangeStop:          100,
							AssignedInstanceId: "consumer-instance-2",
						},
					},
				},
			},
			wantPartitionSlots: []*PartitionSlotToConsumerInstance{
				{
					RangeStart:         0,
					RangeStop:          50,
					AssignedInstanceId: "consumer-instance-1",
				},
				{
					RangeStart:         50,
					RangeStop:          100,
					AssignedInstanceId: "consumer-instance-2",
				},
				{
					RangeStart:         100,
					RangeStop:          150,
					AssignedInstanceId: "consumer-instance-3",
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if gotPartitionSlots := doBalanceSticky(tt.args.partitions, tt.args.consumerInstanceIds, tt.args.prevMapping); !reflect.DeepEqual(gotPartitionSlots, tt.wantPartitionSlots) {
				t.Errorf("doBalanceSticky() = %v, want %v", gotPartitionSlots, tt.wantPartitionSlots)
			}
		})
	}
}

View File

@@ -1,7 +1,5 @@
package sub_coordinator package sub_coordinator
import "time"
type PartitionSlotToConsumerInstance struct { type PartitionSlotToConsumerInstance struct {
RangeStart int32 RangeStart int32
RangeStop int32 RangeStop int32
@@ -16,10 +14,3 @@ type PartitionSlotToConsumerInstanceList struct {
RingSize int32 RingSize int32
Version int64 Version int64
} }
// NewPartitionSlotToConsumerInstanceList builds an empty mapping list for a
// ring of the given size, stamped with the supplied time as its version.
func NewPartitionSlotToConsumerInstanceList(ringSize int32, version time.Time) *PartitionSlotToConsumerInstanceList {
	list := &PartitionSlotToConsumerInstanceList{}
	list.RingSize = ringSize
	list.Version = version.UnixNano()
	return list
}

View File

@@ -90,22 +90,3 @@ type OffsetAwarePublisher struct {
partition *LocalPartition partition *LocalPartition
assignOffsetFn OffsetAssignmentFunc assignOffsetFn OffsetAssignmentFunc
} }
// NewOffsetAwarePublisher creates a new offset-aware publisher
func NewOffsetAwarePublisher(partition *LocalPartition, assignOffsetFn OffsetAssignmentFunc) *OffsetAwarePublisher {
return &OffsetAwarePublisher{
partition: partition,
assignOffsetFn: assignOffsetFn,
}
}
// Publish publishes a message with automatic offset assignment
func (oap *OffsetAwarePublisher) Publish(message *mq_pb.DataMessage) error {
_, err := oap.partition.PublishWithOffset(message, oap.assignOffsetFn)
return err
}
// GetPartition returns the underlying partition
func (oap *OffsetAwarePublisher) GetPartition() *LocalPartition {
return oap.partition
}

View File

@@ -16,15 +16,6 @@ type Partition struct {
UnixTimeNs int64 // in nanoseconds UnixTimeNs int64 // in nanoseconds
} }
func NewPartition(rangeStart, rangeStop, ringSize int32, unixTimeNs int64) *Partition {
return &Partition{
RangeStart: rangeStart,
RangeStop: rangeStop,
RingSize: ringSize,
UnixTimeNs: unixTimeNs,
}
}
func (partition Partition) Equals(other Partition) bool { func (partition Partition) Equals(other Partition) bool {
if partition.RangeStart != other.RangeStart { if partition.RangeStart != other.RangeStart {
return false return false
@@ -57,24 +48,6 @@ func FromPbPartition(partition *schema_pb.Partition) Partition {
} }
} }
func SplitPartitions(targetCount int32, ts int64) []*Partition {
partitions := make([]*Partition, 0, targetCount)
partitionSize := PartitionCount / targetCount
for i := int32(0); i < targetCount; i++ {
partitionStop := (i + 1) * partitionSize
if i == targetCount-1 {
partitionStop = PartitionCount
}
partitions = append(partitions, &Partition{
RangeStart: i * partitionSize,
RangeStop: partitionStop,
RingSize: PartitionCount,
UnixTimeNs: ts,
})
}
return partitions
}
func (partition Partition) ToPbPartition() *schema_pb.Partition { func (partition Partition) ToPbPartition() *schema_pb.Partition {
return &schema_pb.Partition{ return &schema_pb.Partition{
RangeStart: partition.RangeStart, RangeStart: partition.RangeStart,

View File

@@ -3,8 +3,6 @@ package operation
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync"
"time" "time"
"github.com/seaweedfs/seaweedfs/weed/pb" "github.com/seaweedfs/seaweedfs/weed/pb"
@@ -41,118 +39,6 @@ type AssignResult struct {
Replicas []Location `json:"replicas,omitempty"` Replicas []Location `json:"replicas,omitempty"`
} }
// This is a proxy to the master server, only for assigning volume ids.
// It runs via grpc to the master server in streaming mode.
// The connection to the master would only be re-established when the last connection has error.
type AssignProxy struct {
grpcConnection *grpc.ClientConn
pool chan *singleThreadAssignProxy
}
func NewAssignProxy(masterFn GetMasterFn, grpcDialOption grpc.DialOption, concurrency int) (ap *AssignProxy, err error) {
ap = &AssignProxy{
pool: make(chan *singleThreadAssignProxy, concurrency),
}
ap.grpcConnection, err = pb.GrpcDial(context.Background(), masterFn(context.Background()).ToGrpcAddress(), true, grpcDialOption)
if err != nil {
return nil, fmt.Errorf("fail to dial %s: %v", masterFn(context.Background()).ToGrpcAddress(), err)
}
for i := 0; i < concurrency; i++ {
ap.pool <- &singleThreadAssignProxy{}
}
return ap, nil
}
func (ap *AssignProxy) Assign(primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
p := <-ap.pool
defer func() {
ap.pool <- p
}()
return p.doAssign(ap.grpcConnection, primaryRequest, alternativeRequests...)
}
type singleThreadAssignProxy struct {
assignClient master_pb.Seaweed_StreamAssignClient
sync.Mutex
}
func (ap *singleThreadAssignProxy) doAssign(grpcConnection *grpc.ClientConn, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
ap.Lock()
defer ap.Unlock()
if ap.assignClient == nil {
client := master_pb.NewSeaweedClient(grpcConnection)
ap.assignClient, err = client.StreamAssign(context.Background())
if err != nil {
ap.assignClient = nil
return nil, fmt.Errorf("fail to create stream assign client: %w", err)
}
}
var requests []*VolumeAssignRequest
requests = append(requests, primaryRequest)
requests = append(requests, alternativeRequests...)
ret = &AssignResult{}
for _, request := range requests {
if request == nil {
continue
}
req := &master_pb.AssignRequest{
Count: request.Count,
Replication: request.Replication,
Collection: request.Collection,
Ttl: request.Ttl,
DiskType: request.DiskType,
DataCenter: request.DataCenter,
Rack: request.Rack,
DataNode: request.DataNode,
WritableVolumeCount: request.WritableVolumeCount,
}
if err = ap.assignClient.Send(req); err != nil {
ap.assignClient = nil
return nil, fmt.Errorf("StreamAssignSend: %w", err)
}
resp, grpcErr := ap.assignClient.Recv()
if grpcErr != nil {
ap.assignClient = nil
return nil, grpcErr
}
if resp.Error != "" {
// StreamAssign returns transient warmup errors as in-band responses.
// Wrap them as codes.Unavailable so the caller's retry logic can
// classify them as retriable.
if strings.Contains(resp.Error, "warming up") {
return nil, status.Errorf(codes.Unavailable, "StreamAssignRecv: %s", resp.Error)
}
return nil, fmt.Errorf("StreamAssignRecv: %v", resp.Error)
}
ret.Count = resp.Count
ret.Fid = resp.Fid
ret.Url = resp.Location.Url
ret.PublicUrl = resp.Location.PublicUrl
ret.GrpcPort = int(resp.Location.GrpcPort)
ret.Error = resp.Error
ret.Auth = security.EncodedJwt(resp.Auth)
for _, r := range resp.Replicas {
ret.Replicas = append(ret.Replicas, Location{
Url: r.Url,
PublicUrl: r.PublicUrl,
DataCenter: r.DataCenter,
})
}
if ret.Count <= 0 {
continue
}
break
}
return
}
func Assign(ctx context.Context, masterFn GetMasterFn, grpcDialOption grpc.DialOption, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (*AssignResult, error) { func Assign(ctx context.Context, masterFn GetMasterFn, grpcDialOption grpc.DialOption, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (*AssignResult, error) {
var requests []*VolumeAssignRequest var requests []*VolumeAssignRequest

View File

@@ -1,70 +0,0 @@
package operation
import (
"context"
"fmt"
"testing"
"time"
"github.com/seaweedfs/seaweedfs/weed/pb"
"google.golang.org/grpc"
)
func BenchmarkWithConcurrency(b *testing.B) {
concurrencyLevels := []int{1, 10, 100, 1000}
ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
return pb.ServerAddress("localhost:9333")
}, grpc.WithInsecure(), 16)
for _, concurrency := range concurrencyLevels {
b.Run(
fmt.Sprintf("Concurrency-%d", concurrency),
func(b *testing.B) {
for i := 0; i < b.N; i++ {
done := make(chan struct{})
startTime := time.Now()
for j := 0; j < concurrency; j++ {
go func() {
ap.Assign(&VolumeAssignRequest{
Count: 1,
})
done <- struct{}{}
}()
}
for j := 0; j < concurrency; j++ {
<-done
}
duration := time.Since(startTime)
b.Logf("Concurrency: %d, Duration: %v", concurrency, duration)
}
},
)
}
}
func BenchmarkStreamAssign(b *testing.B) {
ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
return pb.ServerAddress("localhost:9333")
}, grpc.WithInsecure(), 16)
for i := 0; i < b.N; i++ {
ap.Assign(&VolumeAssignRequest{
Count: 1,
})
}
}
func BenchmarkUnaryAssign(b *testing.B) {
for i := 0; i < b.N; i++ {
Assign(context.Background(), func(_ context.Context) pb.ServerAddress {
return pb.ServerAddress("localhost:9333")
}, grpc.WithInsecure(), &VolumeAssignRequest{
Count: 1,
})
}
}

View File

@@ -93,11 +93,6 @@ func List(ctx context.Context, filerClient FilerClient, parentDirectoryPath, pre
}) })
} }
func doList(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32) (err error) {
_, err = doListWithSnapshot(ctx, filerClient, fullDirPath, prefix, fn, startFrom, inclusive, limit, 0)
return err
}
func doListWithSnapshot(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32, snapshotTsNs int64) (actualSnapshotTsNs int64, err error) { func doListWithSnapshot(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32, snapshotTsNs int64) (actualSnapshotTsNs int64, err error) {
err = filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error { err = filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
actualSnapshotTsNs, err = DoSeaweedListWithSnapshot(ctx, client, fullDirPath, prefix, fn, startFrom, inclusive, limit, snapshotTsNs) actualSnapshotTsNs, err = DoSeaweedListWithSnapshot(ctx, client, fullDirPath, prefix, fn, startFrom, inclusive, limit, snapshotTsNs)
@@ -212,26 +207,6 @@ func Exists(ctx context.Context, filerClient FilerClient, parentDirectoryPath st
return return
} }
func Touch(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, entryName string, entry *Entry) (err error) {
return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
request := &UpdateEntryRequest{
Directory: parentDirectoryPath,
Entry: entry,
}
glog.V(4).InfofCtx(ctx, "touch entry %v/%v: %v", parentDirectoryPath, entryName, request)
if err := UpdateEntry(ctx, client, request); err != nil {
glog.V(0).InfofCtx(ctx, "touch exists entry %v: %v", request, err)
return fmt.Errorf("touch exists entry %s/%s: %v", parentDirectoryPath, entryName, err)
}
return nil
})
}
func Mkdir(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, dirName string, fn func(entry *Entry)) error { func Mkdir(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, dirName string, fn func(entry *Entry)) error {
return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error { return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
return DoMkdir(ctx, client, parentDirectoryPath, dirName, fn) return DoMkdir(ctx, client, parentDirectoryPath, dirName, fn)
@@ -349,59 +324,3 @@ func DoRemoveWithResponse(ctx context.Context, client SeaweedFilerClient, parent
return resp, nil return resp, nil
} }
} }
// DoDeleteEmptyParentDirectories recursively deletes empty parent directories.
// It stops at root "/" or at stopAtPath.
// For safety, dirPath must be under stopAtPath (when stopAtPath is provided).
// The checked map tracks already-processed directories to avoid redundant work in batch operations.
func DoDeleteEmptyParentDirectories(ctx context.Context, client SeaweedFilerClient, dirPath util.FullPath, stopAtPath util.FullPath, checked map[string]bool) {
if dirPath == "/" || dirPath == stopAtPath {
return
}
// Skip if already checked (for batch delete optimization)
dirPathStr := string(dirPath)
if checked != nil {
if checked[dirPathStr] {
return
}
checked[dirPathStr] = true
}
// Safety check: if stopAtPath is provided, dirPath must be under it (root "/" allows everything)
stopStr := string(stopAtPath)
if stopAtPath != "" && stopStr != "/" && !strings.HasPrefix(dirPathStr+"/", stopStr+"/") {
glog.V(1).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: %s is not under %s, skipping", dirPath, stopAtPath)
return
}
// Check if directory is empty by listing with limit 1
isEmpty := true
err := SeaweedList(ctx, client, dirPathStr, "", func(entry *Entry, isLast bool) error {
isEmpty = false
return io.EOF // Use sentinel error to explicitly stop iteration
}, "", false, 1)
if err != nil && err != io.EOF {
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: error checking %s: %v", dirPath, err)
return
}
if !isEmpty {
// Directory is not empty, stop checking upward
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: directory %s is not empty, stopping cleanup", dirPath)
return
}
// Directory is empty, try to delete it
glog.V(2).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: deleting empty directory %s", dirPath)
parentDir, dirName := dirPath.DirAndName()
if err := DoRemove(ctx, client, parentDir, dirName, false, false, false, false, nil); err == nil {
// Successfully deleted, continue checking upwards
DoDeleteEmptyParentDirectories(ctx, client, util.FullPath(parentDir), stopAtPath, checked)
} else {
// Failed to delete, stop cleanup
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: failed to delete %s: %v", dirPath, err)
}
}

View File

@@ -111,15 +111,6 @@ func BeforeEntrySerialization(chunks []*FileChunk) {
} }
} }
func EnsureFid(chunk *FileChunk) {
if chunk.Fid != nil {
return
}
if fid, err := ToFileIdObject(chunk.FileId); err == nil {
chunk.Fid = fid
}
}
func AfterEntryDeserialization(chunks []*FileChunk) { func AfterEntryDeserialization(chunks []*FileChunk) {
for _, chunk := range chunks { for _, chunk := range chunks {
@@ -309,16 +300,6 @@ func MetadataEventTouchesDirectory(event *SubscribeMetadataResponse, dir string)
MetadataEventTargetDirectory(event) == dir MetadataEventTargetDirectory(event) == dir
} }
func MetadataEventTouchesDirectoryPrefix(event *SubscribeMetadataResponse, prefix string) bool {
if strings.HasPrefix(MetadataEventSourceDirectory(event), prefix) {
return true
}
return event != nil &&
event.EventNotification != nil &&
event.EventNotification.NewEntry != nil &&
strings.HasPrefix(MetadataEventTargetDirectory(event), prefix)
}
func MetadataEventMatchesSubscription(event *SubscribeMetadataResponse, pathPrefix string, pathPrefixes []string, directories []string) bool { func MetadataEventMatchesSubscription(event *SubscribeMetadataResponse, pathPrefix string, pathPrefixes []string, directories []string) bool {
if event == nil { if event == nil {
return false return false

View File

@@ -1,87 +0,0 @@
package filer_pb
import (
"testing"
"google.golang.org/protobuf/proto"
)
func TestFileIdSize(t *testing.T) {
fileIdStr := "11745,0293434534cbb9892b"
fid, _ := ToFileIdObject(fileIdStr)
bytes, _ := proto.Marshal(fid)
println(len(fileIdStr))
println(len(bytes))
}
func TestMetadataEventMatchesSubscription(t *testing.T) {
event := &SubscribeMetadataResponse{
Directory: "/tmp",
EventNotification: &EventNotification{
OldEntry: &Entry{Name: "old-name"},
NewEntry: &Entry{Name: "new-name"},
NewParentPath: "/watched",
},
}
tests := []struct {
name string
pathPrefix string
pathPrefixes []string
directories []string
}{
{
name: "primary path prefix matches rename target",
pathPrefix: "/watched/new-name",
},
{
name: "additional path prefix matches rename target",
pathPrefix: "/data",
pathPrefixes: []string{"/watched"},
},
{
name: "directory watch matches rename target directory",
pathPrefix: "/data",
directories: []string{"/watched"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !MetadataEventMatchesSubscription(event, tt.pathPrefix, tt.pathPrefixes, tt.directories) {
t.Fatalf("MetadataEventMatchesSubscription returned false")
}
})
}
}
func TestMetadataEventTouchesDirectoryHelpers(t *testing.T) {
renameInto := &SubscribeMetadataResponse{
Directory: "/tmp",
EventNotification: &EventNotification{
OldEntry: &Entry{Name: "filer.conf"},
NewEntry: &Entry{Name: "filer.conf"},
NewParentPath: "/etc/seaweedfs",
},
}
if got := MetadataEventTargetDirectory(renameInto); got != "/etc/seaweedfs" {
t.Fatalf("MetadataEventTargetDirectory = %q, want /etc/seaweedfs", got)
}
if !MetadataEventTouchesDirectory(renameInto, "/etc/seaweedfs") {
t.Fatalf("expected rename target to touch /etc/seaweedfs")
}
renameOut := &SubscribeMetadataResponse{
Directory: "/etc/remote",
EventNotification: &EventNotification{
OldEntry: &Entry{Name: "remote.conf"},
NewEntry: &Entry{Name: "remote.conf"},
NewParentPath: "/tmp",
},
}
if !MetadataEventTouchesDirectoryPrefix(renameOut, "/etc/remote") {
t.Fatalf("expected rename source to touch /etc/remote")
}
}

View File

@@ -28,7 +28,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb" "github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/mq_pb" "github.com/seaweedfs/seaweedfs/weed/pb/mq_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
) )
const ( const (
@@ -318,18 +317,6 @@ func WithGrpcClient(streamingMode bool, signature int32, fn func(*grpc.ClientCon
} }
func ParseServerAddress(server string, deltaPort int) (newServerAddress string, err error) {
host, port, parseErr := hostAndPort(server)
if parseErr != nil {
return "", fmt.Errorf("server port parse error: %w", parseErr)
}
newPort := int(port) + deltaPort
return util.JoinHostPort(host, newPort), nil
}
func hostAndPort(address string) (host string, port uint64, err error) { func hostAndPort(address string) (host string, port uint64, err error) {
colonIndex := strings.LastIndex(address, ":") colonIndex := strings.LastIndex(address, ":")
if colonIndex < 0 { if colonIndex < 0 {
@@ -457,10 +444,3 @@ func WithOneOfGrpcFilerClients(streamingMode bool, filerAddresses []ServerAddres
return err return err
} }
func WithWorkerClient(streamingMode bool, workerAddress string, grpcDialOption grpc.DialOption, fn func(client worker_pb.WorkerServiceClient) error) error {
return WithGrpcClient(streamingMode, 0, func(grpcConnection *grpc.ClientConn) error {
client := worker_pb.NewWorkerServiceClient(grpcConnection)
return fn(client)
}, workerAddress, false, grpcDialOption)
}

View File

@@ -157,14 +157,6 @@ func (sa ServerAddresses) ToAddressMap() (addresses map[string]ServerAddress) {
return return
} }
func (sa ServerAddresses) ToAddressStrings() (addresses []string) {
parts := strings.Split(string(sa), ",")
for _, address := range parts {
addresses = append(addresses, address)
}
return
}
func ToAddressStrings(addresses []ServerAddress) []string { func ToAddressStrings(addresses []ServerAddress) []string {
var strings []string var strings []string
for _, addr := range addresses { for _, addr := range addresses {
@@ -172,20 +164,6 @@ func ToAddressStrings(addresses []ServerAddress) []string {
} }
return strings return strings
} }
func ToAddressStringsFromMap(addresses map[string]ServerAddress) []string {
var strings []string
for _, addr := range addresses {
strings = append(strings, string(addr))
}
return strings
}
func FromAddressStrings(strings []string) []ServerAddress {
var addresses []ServerAddress
for _, addr := range strings {
addresses = append(addresses, ServerAddress(addr))
}
return addresses
}
func ParseUrl(input string) (address ServerAddress, path string, err error) { func ParseUrl(input string) (address ServerAddress, path string, err error) {
if !strings.HasPrefix(input, "http://") { if !strings.HasPrefix(input, "http://") {

View File

@@ -449,58 +449,6 @@ func hasEligibleCompaction(
return len(bins) > 0, nil return len(bins) > 0, nil
} }
func countDataManifestsForRewrite(
ctx context.Context,
filerClient filer_pb.SeaweedFilerClient,
bucketName, tablePath string,
manifests []iceberg.ManifestFile,
meta table.Metadata,
predicate *partitionPredicate,
) (int64, error) {
if predicate == nil {
return countDataManifests(manifests), nil
}
specsByID := specByID(meta)
var count int64
for _, mf := range manifests {
if mf.ManifestContent() != iceberg.ManifestContentData {
continue
}
manifestData, err := loadFileByIcebergPath(ctx, filerClient, bucketName, tablePath, mf.FilePath())
if err != nil {
return 0, fmt.Errorf("read manifest %s: %w", mf.FilePath(), err)
}
entries, err := iceberg.ReadManifest(mf, bytes.NewReader(manifestData), true)
if err != nil {
return 0, fmt.Errorf("parse manifest %s: %w", mf.FilePath(), err)
}
if len(entries) == 0 {
continue
}
spec, ok := specsByID[int(mf.PartitionSpecID())]
if !ok {
continue
}
allMatch := len(entries) > 0
for _, entry := range entries {
match, err := predicate.Matches(spec, entry.DataFile().Partition())
if err != nil {
return 0, err
}
if !match {
allMatch = false
break
}
}
if allMatch {
count++
}
}
return count, nil
}
func compactionMinInputFiles(minInputFiles int64) (int, error) { func compactionMinInputFiles(minInputFiles int64) (int, error) {
// Ensure the configured value is positive and fits into the platform's int type // Ensure the configured value is positive and fits into the platform's int type
if minInputFiles <= 0 { if minInputFiles <= 0 {

View File

@@ -137,26 +137,6 @@ func mergePlanningIndexSections(index, existing *planningIndex) *planningIndex {
return index return index
} }
func buildPlanningIndex(
ctx context.Context,
filerClient filer_pb.SeaweedFilerClient,
bucketName, tablePath string,
meta table.Metadata,
config Config,
ops []string,
) (*planningIndex, error) {
currentSnap := meta.CurrentSnapshot()
if currentSnap == nil || currentSnap.ManifestList == "" {
return nil, nil
}
manifests, err := loadCurrentManifests(ctx, filerClient, bucketName, tablePath, meta)
if err != nil {
return nil, err
}
return buildPlanningIndexFromManifests(ctx, filerClient, bucketName, tablePath, meta, config, ops, manifests)
}
func buildPlanningIndexFromManifests( func buildPlanningIndexFromManifests(
ctx context.Context, ctx context.Context,
filerClient filer_pb.SeaweedFilerClient, filerClient filer_pb.SeaweedFilerClient,

Some files were not shown because too many files have changed in this diff Show More